Skip to main content

sage_codegen/
generator.rs

1//! Main code generator.
2
3use crate::emit::Emitter;
4use sage_loader::ModuleTree;
5use sage_parser::{
6    AgentDecl, BinOp, Block, ConstDecl, EnumDecl, EventKind, Expr, FnDecl, Literal, Program,
7    RecordDecl, Stmt, StringPart, UnaryOp,
8};
9use sage_types::TypeExpr;
10
/// Generated Rust project files.
///
/// Holds the complete text of the two files that make up a generated
/// project. Derives the common value-type traits so callers can log,
/// duplicate and compare generated output (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedProject {
    /// The main.rs content.
    pub main_rs: String,
    /// The Cargo.toml content.
    pub cargo_toml: String,
}
18
19/// Generate Rust code from a Sage program (single file).
20pub fn generate(program: &Program, project_name: &str) -> GeneratedProject {
21    let mut gen = Generator::new();
22    let main_rs = gen.generate_program(program);
23    let cargo_toml = gen.generate_cargo_toml(project_name);
24    GeneratedProject {
25        main_rs,
26        cargo_toml,
27    }
28}
29
30/// Generate Rust code from a module tree (multi-file project).
31///
32/// This flattens all modules into a single Rust file, generating all agents
33/// and functions with appropriate visibility modifiers.
34pub fn generate_module_tree(tree: &ModuleTree, project_name: &str) -> GeneratedProject {
35    let mut gen = Generator::new();
36    let main_rs = gen.generate_module_tree(tree);
37    let cargo_toml = gen.generate_cargo_toml(project_name);
38    GeneratedProject {
39        main_rs,
40        cargo_toml,
41    }
42}
43
/// Internal code-generation state.
///
/// Wraps the [`Emitter`] that every `generate_*` / `emit_*` method writes
/// the generated Rust source into.
struct Generator {
    /// Output sink for the generated source text.
    emit: Emitter,
}
47
48impl Generator {
49    fn new() -> Self {
50        Self {
51            emit: Emitter::new(),
52        }
53    }
54
55    fn generate_program(&mut self, program: &Program) -> String {
56        // Prelude
57        self.emit
58            .writeln("//! Generated by Sage compiler. Do not edit.");
59        self.emit.blank_line();
60        self.emit.writeln("use sage_runtime::prelude::*;");
61        self.emit.blank_line();
62
63        // Constants
64        for const_decl in &program.consts {
65            self.generate_const(const_decl);
66            self.emit.blank_line();
67        }
68
69        // Enums
70        for enum_decl in &program.enums {
71            self.generate_enum(enum_decl);
72            self.emit.blank_line();
73        }
74
75        // Records
76        for record in &program.records {
77            self.generate_record(record);
78            self.emit.blank_line();
79        }
80
81        // Functions
82        for func in &program.functions {
83            self.generate_function(func);
84            self.emit.blank_line();
85        }
86
87        // Agents
88        for agent in &program.agents {
89            self.generate_agent(agent);
90            self.emit.blank_line();
91        }
92
93        // Entry point (required for executables)
94        if let Some(run_agent) = &program.run_agent {
95            // RFC-0007: Check if entry agent has an error handler
96            let has_error_handler = program
97                .agents
98                .iter()
99                .find(|a| a.name.name == run_agent.name)
100                .map_or(false, |a| {
101                    a.handlers
102                        .iter()
103                        .any(|h| matches!(h.event, EventKind::Error { .. }))
104                });
105            self.generate_main(&run_agent.name, has_error_handler);
106        }
107
108        std::mem::take(&mut self.emit).finish()
109    }
110
111    fn generate_module_tree(&mut self, tree: &ModuleTree) -> String {
112        // Prelude
113        self.emit
114            .writeln("//! Generated by Sage compiler. Do not edit.");
115        self.emit.blank_line();
116        self.emit.writeln("use sage_runtime::prelude::*;");
117        self.emit.blank_line();
118
119        // Generate all modules, starting with the root
120        // We flatten everything into one file for simplicity
121        // (A more advanced implementation would generate mod.rs files)
122
123        // First, generate non-root modules
124        for (path, module) in &tree.modules {
125            if path != &tree.root {
126                self.emit.write("// Module: ");
127                if path.is_empty() {
128                    self.emit.writeln("(root)");
129                } else {
130                    self.emit.writeln(&path.join("::"));
131                }
132
133                for const_decl in &module.program.consts {
134                    self.generate_const(const_decl);
135                    self.emit.blank_line();
136                }
137
138                for enum_decl in &module.program.enums {
139                    self.generate_enum(enum_decl);
140                    self.emit.blank_line();
141                }
142
143                for record in &module.program.records {
144                    self.generate_record(record);
145                    self.emit.blank_line();
146                }
147
148                for func in &module.program.functions {
149                    self.generate_function(func);
150                    self.emit.blank_line();
151                }
152
153                for agent in &module.program.agents {
154                    self.generate_agent(agent);
155                    self.emit.blank_line();
156                }
157            }
158        }
159
160        // Then, generate the root module
161        if let Some(root_module) = tree.modules.get(&tree.root) {
162            self.emit.writeln("// Root module");
163
164            for const_decl in &root_module.program.consts {
165                self.generate_const(const_decl);
166                self.emit.blank_line();
167            }
168
169            for enum_decl in &root_module.program.enums {
170                self.generate_enum(enum_decl);
171                self.emit.blank_line();
172            }
173
174            for record in &root_module.program.records {
175                self.generate_record(record);
176                self.emit.blank_line();
177            }
178
179            for func in &root_module.program.functions {
180                self.generate_function(func);
181                self.emit.blank_line();
182            }
183
184            for agent in &root_module.program.agents {
185                self.generate_agent(agent);
186                self.emit.blank_line();
187            }
188
189            // Entry point (only in root module)
190            if let Some(run_agent) = &root_module.program.run_agent {
191                // RFC-0007: Check if entry agent has an error handler
192                let has_error_handler = root_module
193                    .program
194                    .agents
195                    .iter()
196                    .find(|a| a.name.name == run_agent.name)
197                    .map_or(false, |a| {
198                        a.handlers
199                            .iter()
200                            .any(|h| matches!(h.event, EventKind::Error { .. }))
201                    });
202                self.generate_main(&run_agent.name, has_error_handler);
203            }
204        }
205
206        std::mem::take(&mut self.emit).finish()
207    }
208
209    fn generate_cargo_toml(&self, name: &str) -> String {
210        // Use a relative path that works from target/sage/<project>/
211        // This assumes the standard project layout
212        format!(
213            r#"[package]
214name = "{name}"
215version = "0.1.0"
216edition = "2021"
217
218[dependencies]
219sage-runtime = {{ path = "../../../crates/sage-runtime" }}
220tokio = {{ version = "1", features = ["full"] }}
221serde = {{ version = "1", features = ["derive"] }}
222serde_json = "1"
223
224# Standalone project, not part of parent workspace
225[workspace]
226"#
227        )
228    }
229
230    fn generate_const(&mut self, const_decl: &ConstDecl) {
231        if const_decl.is_pub {
232            self.emit.write("pub ");
233        }
234        self.emit.write("const ");
235        self.emit.write(&const_decl.name.name);
236        self.emit.write(": ");
237        self.emit_type(&const_decl.ty);
238        self.emit.write(" = ");
239        self.generate_expr(&const_decl.value);
240        self.emit.writeln(";");
241    }
242
243    fn generate_enum(&mut self, enum_decl: &EnumDecl) {
244        if enum_decl.is_pub {
245            self.emit.write("pub ");
246        }
247        self.emit
248            .writeln("#[derive(Debug, Clone, Copy, PartialEq, Eq)]");
249        self.emit.write("enum ");
250        self.emit.write(&enum_decl.name.name);
251        self.emit.writeln(" {");
252        self.emit.indent();
253        for variant in &enum_decl.variants {
254            self.emit.write(&variant.name);
255            self.emit.writeln(",");
256        }
257        self.emit.dedent();
258        self.emit.writeln("}");
259    }
260
261    fn generate_record(&mut self, record: &RecordDecl) {
262        if record.is_pub {
263            self.emit.write("pub ");
264        }
265        self.emit.writeln("#[derive(Debug, Clone)]");
266        self.emit.write("struct ");
267        self.emit.write(&record.name.name);
268        self.emit.writeln(" {");
269        self.emit.indent();
270        for field in &record.fields {
271            self.emit.write(&field.name.name);
272            self.emit.write(": ");
273            self.emit_type(&field.ty);
274            self.emit.writeln(",");
275        }
276        self.emit.dedent();
277        self.emit.writeln("}");
278    }
279
280    fn generate_function(&mut self, func: &FnDecl) {
281        // Function signature with visibility
282        if func.is_pub {
283            self.emit.write("pub ");
284        }
285        self.emit.write("fn ");
286        self.emit.write(&func.name.name);
287        self.emit.write("(");
288
289        for (i, param) in func.params.iter().enumerate() {
290            if i > 0 {
291                self.emit.write(", ");
292            }
293            self.emit.write(&param.name.name);
294            self.emit.write(": ");
295            self.emit_type(&param.ty);
296        }
297
298        self.emit.write(") -> ");
299
300        // RFC-0007: Wrap return type in SageResult if fallible
301        if func.is_fallible {
302            self.emit.write("SageResult<");
303            self.emit_type(&func.return_ty);
304            self.emit.write(">");
305        } else {
306            self.emit_type(&func.return_ty);
307        }
308
309        self.emit.write(" ");
310        self.generate_block(&func.body);
311    }
312
313    fn generate_agent(&mut self, agent: &AgentDecl) {
314        let name = &agent.name.name;
315
316        // Struct definition with visibility
317        if agent.is_pub {
318            self.emit.write("pub ");
319        }
320        self.emit.write("struct ");
321        self.emit.write(name);
322        if agent.beliefs.is_empty() {
323            self.emit.writeln(";");
324        } else {
325            self.emit.writeln(" {");
326            self.emit.indent();
327            for belief in &agent.beliefs {
328                self.emit.write(&belief.name.name);
329                self.emit.write(": ");
330                self.emit_type(&belief.ty);
331                self.emit.writeln(",");
332            }
333            self.emit.dedent();
334            self.emit.writeln("}");
335        }
336        self.emit.blank_line();
337
338        // Find the output type from the start handler
339        let output_type = self.infer_agent_output_type(agent);
340
341        // Impl block
342        self.emit.write("impl ");
343        self.emit.write(name);
344        self.emit.writeln(" {");
345        self.emit.indent();
346
347        // Generate handlers
348        for handler in &agent.handlers {
349            match &handler.event {
350                EventKind::Start => {
351                    self.emit
352                        .write("async fn on_start(self, ctx: AgentContext<");
353                    self.emit.write(&output_type);
354                    self.emit.write(">) -> SageResult<");
355                    self.emit.write(&output_type);
356                    self.emit.writeln("> {");
357                    self.emit.indent();
358                    self.generate_block_contents(&handler.body);
359                    self.emit.dedent();
360                    self.emit.writeln("}");
361                }
362
363                // RFC-0007: Generate on_error handler
364                EventKind::Error { param_name } => {
365                    self.emit.write("async fn on_error(self, ");
366                    self.emit.write(&param_name.name);
367                    self.emit.write(": SageError, ctx: AgentContext<");
368                    self.emit.write(&output_type);
369                    self.emit.write(">) -> SageResult<");
370                    self.emit.write(&output_type);
371                    self.emit.writeln("> {");
372                    self.emit.indent();
373                    self.generate_block_contents(&handler.body);
374                    self.emit.dedent();
375                    self.emit.writeln("}");
376                }
377
378                // Other handlers (message, stop) - future work
379                _ => {}
380            }
381        }
382
383        self.emit.dedent();
384        self.emit.writeln("}");
385    }
386
387    fn generate_main(&mut self, entry_agent: &str, has_error_handler: bool) {
388        self.emit.writeln("#[tokio::main]");
389        self.emit
390            .writeln("async fn main() -> Result<(), Box<dyn std::error::Error>> {");
391        self.emit.indent();
392
393        if has_error_handler {
394            // RFC-0007: Generate error dispatch code
395            self.emit
396                .writeln("let handle = sage_runtime::spawn(|ctx| async move {");
397            self.emit.indent();
398            self.emit.write("match ");
399            self.emit.write(entry_agent);
400            self.emit.writeln(".on_start(ctx.clone()).await {");
401            self.emit.indent();
402            self.emit.writeln("Ok(result) => Ok(result),");
403            self.emit.write("Err(e) => ");
404            self.emit.write(entry_agent);
405            self.emit.writeln(".on_error(e, ctx).await,");
406            self.emit.dedent();
407            self.emit.writeln("}");
408            self.emit.dedent();
409            self.emit.writeln("});");
410        } else {
411            self.emit.write("let handle = sage_runtime::spawn(|ctx| ");
412            self.emit.write(entry_agent);
413            self.emit.writeln(".on_start(ctx));");
414        }
415
416        self.emit.writeln("let result = handle.result().await?;");
417        self.emit.writeln("println!(\"{:?}\", result);");
418        self.emit.writeln("Ok(())");
419
420        self.emit.dedent();
421        self.emit.writeln("}");
422    }
423
    /// Emits `block` as a braced block: `open_brace`, the block's
    /// statements, then `close_brace`.
    ///
    /// NOTE(review): the exact whitespace produced by the brace helpers is
    /// defined by `Emitter`, which is not visible here.
    fn generate_block(&mut self, block: &Block) {
        self.emit.open_brace();
        self.generate_block_contents(block);
        self.emit.close_brace();
    }
429
    /// Emits `block` as a braced block closed with `close_brace_inline`
    /// instead of `close_brace`.
    ///
    /// Used for the `then` block of an `if` that is followed by ` else `;
    /// presumably the inline variant leaves the cursor on the closing-brace
    /// line — confirm against `Emitter`.
    fn generate_block_inline(&mut self, block: &Block) {
        self.emit.open_brace();
        self.generate_block_contents(block);
        self.emit.close_brace_inline();
    }
435
436    fn generate_block_contents(&mut self, block: &Block) {
437        for stmt in &block.stmts {
438            self.generate_stmt(stmt);
439        }
440    }
441
442    fn generate_stmt(&mut self, stmt: &Stmt) {
443        match stmt {
444            Stmt::Let {
445                name, ty, value, ..
446            } => {
447                self.emit.write("let ");
448                if ty.is_some() {
449                    self.emit.write(&name.name);
450                    self.emit.write(": ");
451                    self.emit_type(ty.as_ref().unwrap());
452                } else {
453                    self.emit.write(&name.name);
454                }
455                self.emit.write(" = ");
456                self.generate_expr(value);
457                self.emit.writeln(";");
458            }
459
460            Stmt::Assign { name, value, .. } => {
461                self.emit.write(&name.name);
462                self.emit.write(" = ");
463                self.generate_expr(value);
464                self.emit.writeln(";");
465            }
466
467            Stmt::Return { value, .. } => {
468                self.emit.write("return ");
469                if let Some(expr) = value {
470                    self.generate_expr(expr);
471                }
472                self.emit.writeln(";");
473            }
474
475            Stmt::If {
476                condition,
477                then_block,
478                else_block,
479                ..
480            } => {
481                self.emit.write("if ");
482                self.generate_expr(condition);
483                self.emit.write(" ");
484                if else_block.is_some() {
485                    self.generate_block_inline(then_block);
486                    self.emit.write(" else ");
487                    match else_block.as_ref().unwrap() {
488                        sage_parser::ElseBranch::Block(block) => {
489                            self.generate_block(block);
490                        }
491                        sage_parser::ElseBranch::ElseIf(stmt) => {
492                            self.generate_stmt(stmt);
493                        }
494                    }
495                } else {
496                    self.generate_block(then_block);
497                }
498            }
499
500            Stmt::For {
501                var, iter, body, ..
502            } => {
503                self.emit.write("for ");
504                self.emit.write(&var.name);
505                self.emit.write(" in ");
506                self.generate_expr(iter);
507                self.emit.write(" ");
508                self.generate_block(body);
509            }
510
511            Stmt::While {
512                condition, body, ..
513            } => {
514                self.emit.write("while ");
515                self.generate_expr(condition);
516                self.emit.write(" ");
517                self.generate_block(body);
518            }
519
520            Stmt::Loop { body, .. } => {
521                self.emit.write("loop ");
522                self.generate_block(body);
523            }
524
525            Stmt::Break { .. } => {
526                self.emit.writeln("break;");
527            }
528
529            Stmt::Expr { expr, .. } => {
530                // Handle emit specially
531                if let Expr::Emit { value, .. } = expr {
532                    self.emit.write("return ctx.emit(");
533                    self.generate_expr(value);
534                    self.emit.writeln(");");
535                } else {
536                    self.generate_expr(expr);
537                    self.emit.writeln(";");
538                }
539            }
540        }
541    }
542
543    fn generate_expr(&mut self, expr: &Expr) {
544        match expr {
545            Expr::Literal { value, .. } => {
546                self.emit_literal(value);
547            }
548
549            Expr::Var { name, .. } => {
550                self.emit.write(&name.name);
551            }
552
553            Expr::Binary {
554                op, left, right, ..
555            } => {
556                // Handle string concatenation specially
557                if matches!(op, BinOp::Concat) {
558                    self.emit.write("format!(\"{}{}\", ");
559                    self.generate_expr(left);
560                    self.emit.write(", ");
561                    self.generate_expr(right);
562                    self.emit.write(")");
563                } else {
564                    self.emit.write("(");
565                    self.generate_expr(left);
566                    self.emit.write(" ");
567                    self.emit_binop(op);
568                    self.emit.write(" ");
569                    self.generate_expr(right);
570                    self.emit.write(")");
571                }
572            }
573
574            Expr::Unary { op, operand, .. } => {
575                self.emit_unaryop(op);
576                self.generate_expr(operand);
577            }
578
579            Expr::Call { name, args, .. } => {
580                let fn_name = &name.name;
581
582                // Handle builtins
583                match fn_name.as_str() {
584                    "print" => {
585                        self.emit.write("println!(\"{}\", ");
586                        self.generate_expr(&args[0]);
587                        self.emit.write(")");
588                    }
589                    "str" => {
590                        self.generate_expr(&args[0]);
591                        self.emit.write(".to_string()");
592                    }
593                    "len" => {
594                        self.generate_expr(&args[0]);
595                        self.emit.write(".len() as i64");
596                    }
597                    _ => {
598                        self.emit.write(fn_name);
599                        self.emit.write("(");
600                        for (i, arg) in args.iter().enumerate() {
601                            if i > 0 {
602                                self.emit.write(", ");
603                            }
604                            self.generate_expr(arg);
605                        }
606                        self.emit.write(")");
607                    }
608                }
609            }
610
611            Expr::SelfField { field, .. } => {
612                self.emit.write("self.");
613                self.emit.write(&field.name);
614            }
615
616            Expr::SelfMethodCall { method, args, .. } => {
617                self.emit.write("self.");
618                self.emit.write(&method.name);
619                self.emit.write("(");
620                for (i, arg) in args.iter().enumerate() {
621                    if i > 0 {
622                        self.emit.write(", ");
623                    }
624                    self.generate_expr(arg);
625                }
626                self.emit.write(")");
627            }
628
629            Expr::List { elements, .. } => {
630                self.emit.write("vec![");
631                for (i, elem) in elements.iter().enumerate() {
632                    if i > 0 {
633                        self.emit.write(", ");
634                    }
635                    self.generate_expr(elem);
636                }
637                self.emit.write("]");
638            }
639
640            Expr::Paren { inner, .. } => {
641                self.emit.write("(");
642                self.generate_expr(inner);
643                self.emit.write(")");
644            }
645
646            Expr::Infer { template, .. } => {
647                self.emit.write("ctx.infer_string(&");
648                self.emit_string_template(template);
649                self.emit.write(").await?");
650            }
651
652            Expr::Spawn { agent, fields, .. } => {
653                self.emit.write("sage_runtime::spawn(|ctx| ");
654                self.emit.write(&agent.name);
655                if fields.is_empty() {
656                    self.emit.write(".on_start(ctx))");
657                } else {
658                    self.emit.write(" { ");
659                    for (i, field) in fields.iter().enumerate() {
660                        if i > 0 {
661                            self.emit.write(", ");
662                        }
663                        self.emit.write(&field.name.name);
664                        self.emit.write(": ");
665                        self.generate_expr(&field.value);
666                    }
667                    self.emit.write(" }.on_start(ctx))");
668                }
669            }
670
671            Expr::Await { handle, .. } => {
672                self.generate_expr(handle);
673                self.emit.write(".result().await?");
674            }
675
676            Expr::Send {
677                handle, message, ..
678            } => {
679                self.generate_expr(handle);
680                self.emit.write(".send(sage_runtime::Message::new(");
681                self.generate_expr(message);
682                self.emit.write(")?).await?");
683            }
684
685            Expr::Emit { value, .. } => {
686                self.emit.write("ctx.emit(");
687                self.generate_expr(value);
688                self.emit.write(")");
689            }
690
691            Expr::StringInterp { template, .. } => {
692                self.emit_string_template(template);
693            }
694
695            // TODO: Implement in RFC-0005
696            Expr::Match {
697                scrutinee, arms, ..
698            } => {
699                self.emit.write("match ");
700                self.generate_expr(scrutinee);
701                self.emit.writeln(" {");
702                self.emit.indent();
703                for arm in arms {
704                    self.emit_pattern(&arm.pattern);
705                    self.emit.write(" => ");
706                    self.generate_expr(&arm.body);
707                    self.emit.writeln(",");
708                }
709                self.emit.dedent();
710                self.emit.write("}");
711            }
712
713            // TODO: Implement in RFC-0005
714            Expr::RecordConstruct { name, fields, .. } => {
715                self.emit.write(&name.name);
716                self.emit.write(" { ");
717                for (i, field) in fields.iter().enumerate() {
718                    if i > 0 {
719                        self.emit.write(", ");
720                    }
721                    self.emit.write(&field.name.name);
722                    self.emit.write(": ");
723                    self.generate_expr(&field.value);
724                }
725                self.emit.write(" }");
726            }
727
728            // TODO: Implement in RFC-0005
729            Expr::FieldAccess { object, field, .. } => {
730                self.generate_expr(object);
731                self.emit.write(".");
732                self.emit.write(&field.name);
733            }
734
735            Expr::Receive { .. } => {
736                self.emit.write("ctx.receive().await?");
737            }
738
739            // RFC-0007: Error handling
740            Expr::Try { expr, .. } => {
741                // Generate the inner expression with ? for error propagation
742                self.generate_expr(expr);
743                self.emit.write("?");
744            }
745
746            Expr::Catch {
747                expr,
748                error_bind,
749                recovery,
750                ..
751            } => {
752                // Generate a match expression to handle the Result
753                self.emit.write("match ");
754                self.generate_expr(expr);
755                self.emit.writeln(" {");
756                self.emit.indent();
757
758                // Ok arm - unwrap the value
759                self.emit.writeln("Ok(__val) => __val,");
760
761                // Err arm - run recovery
762                if let Some(err_name) = error_bind {
763                    self.emit.write("Err(");
764                    self.emit.write(&err_name.name);
765                    self.emit.write(") => ");
766                } else {
767                    self.emit.write("Err(_) => ");
768                }
769                self.generate_expr(recovery);
770                self.emit.writeln(",");
771
772                self.emit.dedent();
773                self.emit.write("}");
774            }
775        }
776    }
777
778    fn emit_pattern(&mut self, pattern: &sage_parser::Pattern) {
779        use sage_parser::Pattern;
780        match pattern {
781            Pattern::Wildcard { .. } => {
782                self.emit.write("_");
783            }
784            Pattern::Variant {
785                enum_name, variant, ..
786            } => {
787                if let Some(enum_name) = enum_name {
788                    self.emit.write(&enum_name.name);
789                    self.emit.write("::");
790                }
791                self.emit.write(&variant.name);
792            }
793            Pattern::Literal { value, .. } => {
794                self.emit_literal(value);
795            }
796            Pattern::Binding { name, .. } => {
797                self.emit.write(&name.name);
798            }
799        }
800    }
801
802    fn emit_literal(&mut self, lit: &Literal) {
803        match lit {
804            Literal::Int(n) => {
805                self.emit.write(&format!("{n}_i64"));
806            }
807            Literal::Float(f) => {
808                self.emit.write(&format!("{f}_f64"));
809            }
810            Literal::Bool(b) => {
811                self.emit.write(if *b { "true" } else { "false" });
812            }
813            Literal::String(s) => {
814                // Escape the string for Rust
815                self.emit.write("\"");
816                for c in s.chars() {
817                    match c {
818                        '"' => self.emit.write_raw("\\\""),
819                        '\\' => self.emit.write_raw("\\\\"),
820                        '\n' => self.emit.write_raw("\\n"),
821                        '\r' => self.emit.write_raw("\\r"),
822                        '\t' => self.emit.write_raw("\\t"),
823                        _ => self.emit.write_raw(&c.to_string()),
824                    }
825                }
826                self.emit.write("\".to_string()");
827            }
828        }
829    }
830
831    fn emit_string_template(&mut self, template: &sage_parser::StringTemplate) {
832        if !template.has_interpolations() {
833            // Simple string literal
834            if let Some(StringPart::Literal(s)) = template.parts.first() {
835                self.emit.write("\"");
836                self.emit.write_raw(s);
837                self.emit.write("\".to_string()");
838            }
839            return;
840        }
841
842        // Build format string and args
843        self.emit.write("format!(\"");
844        for part in &template.parts {
845            match part {
846                StringPart::Literal(s) => {
847                    // Escape braces for format string
848                    let escaped = s.replace('{', "{{").replace('}', "}}");
849                    self.emit.write_raw(&escaped);
850                }
851                StringPart::Interpolation(_) => {
852                    self.emit.write_raw("{}");
853                }
854            }
855        }
856        self.emit.write("\"");
857
858        // Add the interpolation args
859        for part in &template.parts {
860            if let StringPart::Interpolation(ident) = part {
861                self.emit.write(", ");
862                self.emit.write(&ident.name);
863            }
864        }
865        self.emit.write(")");
866    }
867
868    fn emit_type(&mut self, ty: &TypeExpr) {
869        match ty {
870            TypeExpr::Int => self.emit.write("i64"),
871            TypeExpr::Float => self.emit.write("f64"),
872            TypeExpr::Bool => self.emit.write("bool"),
873            TypeExpr::String => self.emit.write("String"),
874            TypeExpr::Unit => self.emit.write("()"),
875            TypeExpr::List(inner) => {
876                self.emit.write("Vec<");
877                self.emit_type(inner);
878                self.emit.write(">");
879            }
880            TypeExpr::Option(inner) => {
881                self.emit.write("Option<");
882                self.emit_type(inner);
883                self.emit.write(">");
884            }
885            TypeExpr::Inferred(inner) => {
886                // Inferred<T> just becomes T at runtime
887                self.emit_type(inner);
888            }
889            TypeExpr::Agent(agent_name) => {
890                // Agent handles use the agent's output type, but we don't know it here
891                // For now, just use a generic output type
892                self.emit.write("AgentHandle<");
893                self.emit.write(&agent_name.name);
894                self.emit.write("Output>");
895            }
896            TypeExpr::Named(name) => {
897                self.emit.write(&name.name);
898            }
899
900            // RFC-0007: Error handling
901            TypeExpr::Error => {
902                self.emit.write("sage_runtime::SageError");
903            }
904        }
905    }
906
907    fn emit_binop(&mut self, op: &BinOp) {
908        let s = match op {
909            BinOp::Add => "+",
910            BinOp::Sub => "-",
911            BinOp::Mul => "*",
912            BinOp::Div => "/",
913            BinOp::Eq => "==",
914            BinOp::Ne => "!=",
915            BinOp::Lt => "<",
916            BinOp::Gt => ">",
917            BinOp::Le => "<=",
918            BinOp::Ge => ">=",
919            BinOp::And => "&&",
920            BinOp::Or => "||",
921            BinOp::Concat => "++", // Handled specially above
922        };
923        self.emit.write(s);
924    }
925
926    fn emit_unaryop(&mut self, op: &UnaryOp) {
927        let s = match op {
928            UnaryOp::Neg => "-",
929            UnaryOp::Not => "!",
930        };
931        self.emit.write(s);
932    }
933
934    fn infer_agent_output_type(&self, agent: &AgentDecl) -> String {
935        // Look for emit expression in start handler to infer return type
936        // For now, default to i64
937        for handler in &agent.handlers {
938            if let EventKind::Start = &handler.event {
939                if let Some(ty) = self.find_emit_type(&handler.body) {
940                    return ty;
941                }
942            }
943        }
944        "i64".to_string()
945    }
946
947    fn find_emit_type(&self, block: &Block) -> Option<String> {
948        for stmt in &block.stmts {
949            if let Stmt::Expr { expr, .. } = stmt {
950                if let Expr::Emit { value, .. } = expr {
951                    return Some(self.infer_expr_type(value));
952                }
953            }
954            // Check nested blocks
955            if let Stmt::If {
956                then_block,
957                else_block,
958                ..
959            } = stmt
960            {
961                if let Some(ty) = self.find_emit_type(then_block) {
962                    return Some(ty);
963                }
964                if let Some(else_branch) = else_block {
965                    if let sage_parser::ElseBranch::Block(block) = else_branch {
966                        if let Some(ty) = self.find_emit_type(block) {
967                            return Some(ty);
968                        }
969                    }
970                }
971            }
972        }
973        None
974    }
975
976    fn infer_expr_type(&self, expr: &Expr) -> String {
977        match expr {
978            Expr::Literal { value, .. } => match value {
979                Literal::Int(_) => "i64".to_string(),
980                Literal::Float(_) => "f64".to_string(),
981                Literal::Bool(_) => "bool".to_string(),
982                Literal::String(_) => "String".to_string(),
983            },
984            Expr::Var { .. } => "i64".to_string(), // Conservative default
985            Expr::Binary { op, .. } => {
986                if matches!(
987                    op,
988                    BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge
989                ) {
990                    "bool".to_string()
991                } else if matches!(op, BinOp::Concat) {
992                    "String".to_string()
993                } else {
994                    "i64".to_string()
995                }
996            }
997            Expr::Infer { .. } | Expr::StringInterp { .. } => "String".to_string(),
998            Expr::Call { name, .. } if name.name == "str" => "String".to_string(),
999            Expr::Call { name, .. } if name.name == "len" => "i64".to_string(),
1000            _ => "i64".to_string(),
1001        }
1002    }
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use sage_lexer::lex;
    use sage_parser::parse;
    use std::sync::Arc;

    /// Lexes, parses and generates a single-file Sage program, returning the
    /// produced `main.rs` text. Panics on lexing or parsing failure.
    fn generate_source(source: &str) -> String {
        let tokens = lex(source).expect("lexing failed");
        let (parsed, errors) = parse(tokens.tokens(), Arc::<str>::from(source));
        assert!(errors.is_empty(), "parse errors: {errors:?}");
        generate(&parsed.expect("should parse"), "test").main_rs
    }

    /// A single agent with a `start` handler yields a unit struct, an async
    /// handler and a tokio entry point.
    #[test]
    fn generate_minimal_program() {
        let sage_src = r#"
            agent Main {
                on start {
                    emit(42);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("struct Main;"));
        assert!(generated.contains("async fn on_start"));
        assert!(generated.contains("ctx.emit(42_i64)"));
        assert!(generated.contains("#[tokio::main]"));
    }

    /// Free functions keep their signature, with Sage types mapped to Rust.
    #[test]
    fn generate_function() {
        let sage_src = r#"
            fn add(a: Int, b: Int) -> Int {
                return a + b;
            }
            agent Main {
                on start {
                    emit(add(1, 2));
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("fn add(a: i64, b: i64) -> i64"));
        assert!(generated.contains("return (a + b);"));
    }

    /// Agent beliefs become struct fields accessed through `self`.
    #[test]
    fn generate_agent_with_beliefs() {
        let sage_src = r#"
            agent Worker {
                value: Int

                on start {
                    emit(self.value * 2);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("struct Worker {"));
        assert!(generated.contains("value: i64,"));
        assert!(generated.contains("self.value"));
    }

    /// Interpolated strings are lowered to a `format!` call.
    #[test]
    fn generate_string_interpolation() {
        let sage_src = r#"
            agent Main {
                on start {
                    let name = "World";
                    let msg = "Hello, {name}!";
                    print(msg);
                    emit(0);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("format!(\"Hello, {}!\", name)"));
    }

    /// `if`/`else` produces a parenthesised condition and an `else` branch.
    #[test]
    fn generate_control_flow() {
        let sage_src = r#"
            agent Main {
                on start {
                    let x = 10;
                    if x > 5 {
                        emit(1);
                    } else {
                        emit(0);
                    }
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(
            generated.contains("if (x > 5_i64)"),
            "output:\n{output}",
            output = generated
        );
        // Only checks that an `else` branch is emitted at all.
        assert!(
            generated.contains("else"),
            "output:\n{output}",
            output = generated
        );
    }

    /// `for` over a list literal and `while` loops are both emitted.
    #[test]
    fn generate_loops() {
        let sage_src = r#"
            agent Main {
                on start {
                    for x in [1, 2, 3] {
                        print(str(x));
                    }
                    let n = 0;
                    while n < 5 {
                        n = n + 1;
                    }
                    emit(n);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("for x in vec![1_i64, 2_i64, 3_i64]"));
        assert!(generated.contains("while (n < 5_i64)"));
    }

    /// `pub fn` keeps its visibility in the generated code.
    #[test]
    fn generate_pub_function() {
        let sage_src = r#"
            pub fn helper(x: Int) -> Int {
                return x * 2;
            }
            agent Main {
                on start {
                    emit(helper(21));
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("pub fn helper(x: i64) -> i64"));
    }

    /// `pub agent` produces a `pub struct`.
    #[test]
    fn generate_pub_agent() {
        let sage_src = r#"
            pub agent Worker {
                on start {
                    emit(42);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("pub struct Worker;"));
    }

    /// The module-tree entry point handles a tree loaded from one file on disk.
    #[test]
    fn generate_module_tree_simple() {
        use sage_loader::load_single_file;
        use std::fs;
        use tempfile::TempDir;

        let tmp_dir = TempDir::new().unwrap();
        let entry_path = tmp_dir.path().join("test.sg");
        fs::write(
            &entry_path,
            r#"
agent Main {
    on start {
        emit(42);
    }
}
run Main;
"#,
        )
        .unwrap();

        let module_tree = load_single_file(&entry_path).unwrap();
        let generated = generate_module_tree(&module_tree, "test");

        assert!(generated.main_rs.contains("struct Main;"));
        assert!(generated.main_rs.contains("async fn on_start"));
        assert!(generated.main_rs.contains("#[tokio::main]"));
    }

    /// Records become derived structs; construction and field access carry over.
    #[test]
    fn generate_record_declaration() {
        let sage_src = r#"
            record Point {
                x: Int,
                y: Int,
            }
            agent Main {
                on start {
                    let p = Point { x: 10, y: 20 };
                    emit(p.x);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("#[derive(Debug, Clone)]"));
        assert!(generated.contains("struct Point {"));
        assert!(generated.contains("x: i64,"));
        assert!(generated.contains("y: i64,"));
        assert!(generated.contains("Point { x: 10_i64, y: 20_i64 }"));
        assert!(generated.contains("p.x"));
    }

    /// Enums become derived Rust enums with the same variants.
    #[test]
    fn generate_enum_declaration() {
        let sage_src = r#"
            enum Status {
                Active,
                Inactive,
                Pending,
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("#[derive(Debug, Clone, Copy, PartialEq, Eq)]"));
        assert!(generated.contains("enum Status {"));
        assert!(generated.contains("Active,"));
        assert!(generated.contains("Inactive,"));
        assert!(generated.contains("Pending,"));
    }

    /// Top-level constants are emitted as Rust `const` items.
    #[test]
    fn generate_const_declaration() {
        let sage_src = r#"
            const MAX_SIZE: Int = 100;
            const GREETING: String = "Hello";
            agent Main {
                on start {
                    emit(MAX_SIZE);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("const MAX_SIZE: i64 = 100_i64;"));
        assert!(generated.contains("const GREETING: String = \"Hello\".to_string();"));
    }

    /// `match` over enum variants maps to a Rust `match` with one arm each.
    #[test]
    fn generate_match_expression() {
        let sage_src = r#"
            enum Status {
                Active,
                Inactive,
            }
            fn check_status(s: Status) -> Int {
                return match s {
                    Active => 1,
                    Inactive => 0,
                };
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("match s {"));
        assert!(generated.contains("Active => 1_i64,"));
        assert!(generated.contains("Inactive => 0_i64,"));
    }

    // =========================================================================
    // RFC-0007: Error handling codegen tests
    // =========================================================================

    /// A `fails` function returns `SageResult<T>`.
    #[test]
    fn generate_fallible_function() {
        let sage_src = r#"
            fn get_data(url: String) -> String fails {
                return url;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("fn get_data(url: String) -> SageResult<String>"));
    }

    /// `try expr` is lowered to the `?` operator.
    #[test]
    fn generate_try_expression() {
        let sage_src = r#"
            fn fallible() -> Int fails { return 42; }
            fn caller() -> Int fails {
                let x = try fallible();
                return x;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("fallible()?"));
    }

    /// `expr catch { .. }` is lowered to a `match` on the result.
    #[test]
    fn generate_catch_expression() {
        let sage_src = r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch { 0 };
                    emit(x);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("match fallible()"));
        assert!(generated.contains("Ok(__val) => __val"));
        assert!(generated.contains("Err(_) => 0_i64"));
    }

    /// `catch(e)` binds the error under the given name in the `Err` arm.
    #[test]
    fn generate_catch_with_binding() {
        let sage_src = r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch(e) { 0 };
                    emit(x);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        assert!(generated.contains("Err(e) => 0_i64"));
    }

    /// An `on error` handler yields an `on_error` method and a dispatch to it.
    #[test]
    fn generate_on_error_handler() {
        let sage_src = r#"
            agent Main {
                on start {
                    emit(0);
                }
                on error(e) {
                    emit(1);
                }
            }
            run Main;
        "#;

        let generated = generate_source(sage_src);
        // The handler itself.
        assert!(generated.contains("async fn on_error(self, e: SageError"));
        // The dispatch from the main entry point on failure.
        assert!(generated.contains(".on_error(e, ctx)"));
    }
}