Skip to main content

sage_codegen/
generator.rs

1//! Main code generator.
2
3use crate::emit::Emitter;
4use sage_loader::ModuleTree;
5use sage_parser::{
6    AgentDecl, BinOp, Block, ConstDecl, EnumDecl, EventKind, Expr, FnDecl, Literal, Program,
7    RecordDecl, Stmt, StringPart, UnaryOp,
8};
9use sage_types::TypeExpr;
10
/// The files produced by code generation for one standalone Rust crate.
pub struct GeneratedProject {
    /// Rust source text intended for the crate's `main.rs`.
    pub main_rs: String,
    /// Manifest text intended for the crate's `Cargo.toml`.
    pub cargo_toml: String,
}
18
19/// Generate Rust code from a Sage program (single file).
20pub fn generate(program: &Program, project_name: &str) -> GeneratedProject {
21    let mut gen = Generator::new();
22    let main_rs = gen.generate_program(program);
23    let cargo_toml = gen.generate_cargo_toml(project_name);
24    GeneratedProject {
25        main_rs,
26        cargo_toml,
27    }
28}
29
30/// Generate Rust code from a module tree (multi-file project).
31///
32/// This flattens all modules into a single Rust file, generating all agents
33/// and functions with appropriate visibility modifiers.
34pub fn generate_module_tree(tree: &ModuleTree, project_name: &str) -> GeneratedProject {
35    let mut gen = Generator::new();
36    let main_rs = gen.generate_module_tree(tree);
37    let cargo_toml = gen.generate_cargo_toml(project_name);
38    GeneratedProject {
39        main_rs,
40        cargo_toml,
41    }
42}
43
44struct Generator {
45    emit: Emitter,
46}
47
48impl Generator {
49    fn new() -> Self {
50        Self {
51            emit: Emitter::new(),
52        }
53    }
54
55    fn generate_program(&mut self, program: &Program) -> String {
56        // Prelude
57        self.emit
58            .writeln("//! Generated by Sage compiler. Do not edit.");
59        self.emit.blank_line();
60        self.emit.writeln("use sage_runtime::prelude::*;");
61        self.emit.blank_line();
62
63        // Constants
64        for const_decl in &program.consts {
65            self.generate_const(const_decl);
66            self.emit.blank_line();
67        }
68
69        // Enums
70        for enum_decl in &program.enums {
71            self.generate_enum(enum_decl);
72            self.emit.blank_line();
73        }
74
75        // Records
76        for record in &program.records {
77            self.generate_record(record);
78            self.emit.blank_line();
79        }
80
81        // Functions
82        for func in &program.functions {
83            self.generate_function(func);
84            self.emit.blank_line();
85        }
86
87        // Agents
88        for agent in &program.agents {
89            self.generate_agent(agent);
90            self.emit.blank_line();
91        }
92
93        // Entry point (required for executables)
94        if let Some(run_agent) = &program.run_agent {
95            // Find the entry agent
96            if let Some(agent) = program.agents.iter().find(|a| a.name.name == run_agent.name) {
97                self.generate_main(agent);
98            }
99        }
100
101        std::mem::take(&mut self.emit).finish()
102    }
103
104    fn generate_module_tree(&mut self, tree: &ModuleTree) -> String {
105        // Prelude
106        self.emit
107            .writeln("//! Generated by Sage compiler. Do not edit.");
108        self.emit.blank_line();
109        self.emit.writeln("use sage_runtime::prelude::*;");
110        self.emit.blank_line();
111
112        // Generate all modules, starting with the root
113        // We flatten everything into one file for simplicity
114        // (A more advanced implementation would generate mod.rs files)
115
116        // First, generate non-root modules
117        for (path, module) in &tree.modules {
118            if path != &tree.root {
119                self.emit.write("// Module: ");
120                if path.is_empty() {
121                    self.emit.writeln("(root)");
122                } else {
123                    self.emit.writeln(&path.join("::"));
124                }
125
126                for const_decl in &module.program.consts {
127                    self.generate_const(const_decl);
128                    self.emit.blank_line();
129                }
130
131                for enum_decl in &module.program.enums {
132                    self.generate_enum(enum_decl);
133                    self.emit.blank_line();
134                }
135
136                for record in &module.program.records {
137                    self.generate_record(record);
138                    self.emit.blank_line();
139                }
140
141                for func in &module.program.functions {
142                    self.generate_function(func);
143                    self.emit.blank_line();
144                }
145
146                for agent in &module.program.agents {
147                    self.generate_agent(agent);
148                    self.emit.blank_line();
149                }
150            }
151        }
152
153        // Then, generate the root module
154        if let Some(root_module) = tree.modules.get(&tree.root) {
155            self.emit.writeln("// Root module");
156
157            for const_decl in &root_module.program.consts {
158                self.generate_const(const_decl);
159                self.emit.blank_line();
160            }
161
162            for enum_decl in &root_module.program.enums {
163                self.generate_enum(enum_decl);
164                self.emit.blank_line();
165            }
166
167            for record in &root_module.program.records {
168                self.generate_record(record);
169                self.emit.blank_line();
170            }
171
172            for func in &root_module.program.functions {
173                self.generate_function(func);
174                self.emit.blank_line();
175            }
176
177            for agent in &root_module.program.agents {
178                self.generate_agent(agent);
179                self.emit.blank_line();
180            }
181
182            // Entry point (only in root module)
183            if let Some(run_agent) = &root_module.program.run_agent {
184                // Find the entry agent
185                if let Some(agent) = root_module
186                    .program
187                    .agents
188                    .iter()
189                    .find(|a| a.name.name == run_agent.name)
190                {
191                    self.generate_main(agent);
192                }
193            }
194        }
195
196        std::mem::take(&mut self.emit).finish()
197    }
198
199    fn generate_cargo_toml(&self, name: &str) -> String {
200        // Use a relative path that works from target/sage/<project>/
201        // This assumes the standard project layout
202        format!(
203            r#"[package]
204name = "{name}"
205version = "0.1.0"
206edition = "2021"
207
208[dependencies]
209sage-runtime = {{ path = "../../../crates/sage-runtime" }}
210tokio = {{ version = "1", features = ["full"] }}
211serde = {{ version = "1", features = ["derive"] }}
212serde_json = "1"
213
214# Standalone project, not part of parent workspace
215[workspace]
216"#
217        )
218    }
219
220    fn generate_const(&mut self, const_decl: &ConstDecl) {
221        if const_decl.is_pub {
222            self.emit.write("pub ");
223        }
224        self.emit.write("const ");
225        self.emit.write(&const_decl.name.name);
226        self.emit.write(": ");
227        self.emit_type(&const_decl.ty);
228        self.emit.write(" = ");
229        self.generate_expr(&const_decl.value);
230        self.emit.writeln(";");
231    }
232
233    fn generate_enum(&mut self, enum_decl: &EnumDecl) {
234        if enum_decl.is_pub {
235            self.emit.write("pub ");
236        }
237        self.emit
238            .writeln("#[derive(Debug, Clone, Copy, PartialEq, Eq)]");
239        self.emit.write("enum ");
240        self.emit.write(&enum_decl.name.name);
241        self.emit.writeln(" {");
242        self.emit.indent();
243        for variant in &enum_decl.variants {
244            self.emit.write(&variant.name.name);
245            if let Some(payload_ty) = &variant.payload {
246                self.emit.write("(");
247                self.emit_type(payload_ty);
248                self.emit.write(")");
249            }
250            self.emit.writeln(",");
251        }
252        self.emit.dedent();
253        self.emit.writeln("}");
254    }
255
256    fn generate_record(&mut self, record: &RecordDecl) {
257        if record.is_pub {
258            self.emit.write("pub ");
259        }
260        self.emit.writeln("#[derive(Debug, Clone)]");
261        self.emit.write("struct ");
262        self.emit.write(&record.name.name);
263        self.emit.writeln(" {");
264        self.emit.indent();
265        for field in &record.fields {
266            self.emit.write(&field.name.name);
267            self.emit.write(": ");
268            self.emit_type(&field.ty);
269            self.emit.writeln(",");
270        }
271        self.emit.dedent();
272        self.emit.writeln("}");
273    }
274
275    fn generate_function(&mut self, func: &FnDecl) {
276        // Function signature with visibility
277        if func.is_pub {
278            self.emit.write("pub ");
279        }
280        self.emit.write("fn ");
281        self.emit.write(&func.name.name);
282        self.emit.write("(");
283
284        for (i, param) in func.params.iter().enumerate() {
285            if i > 0 {
286                self.emit.write(", ");
287            }
288            self.emit.write(&param.name.name);
289            self.emit.write(": ");
290            self.emit_type(&param.ty);
291        }
292
293        self.emit.write(") -> ");
294
295        // RFC-0007: Wrap return type in SageResult if fallible
296        if func.is_fallible {
297            self.emit.write("SageResult<");
298            self.emit_type(&func.return_ty);
299            self.emit.write(">");
300        } else {
301            self.emit_type(&func.return_ty);
302        }
303
304        self.emit.write(" ");
305        self.generate_block(&func.body);
306    }
307
308    fn generate_agent(&mut self, agent: &AgentDecl) {
309        let name = &agent.name.name;
310
311        // RFC-0011: Check for tool usage
312        let has_tools = !agent.tool_uses.is_empty();
313        let needs_struct_body = !agent.beliefs.is_empty() || has_tools;
314
315        // Struct definition with visibility
316        if agent.is_pub {
317            self.emit.write("pub ");
318        }
319        self.emit.write("struct ");
320        self.emit.write(name);
321        if !needs_struct_body {
322            self.emit.writeln(";");
323        } else {
324            self.emit.writeln(" {");
325            self.emit.indent();
326
327            // RFC-0011: Generate tool fields
328            for tool_use in &agent.tool_uses {
329                // Generate field like: http: HttpClient
330                self.emit.write(&tool_use.name.to_lowercase());
331                self.emit.write(": ");
332                self.emit.write(&tool_use.name);
333                self.emit.writeln("Client,");
334            }
335
336            // Regular belief fields
337            for belief in &agent.beliefs {
338                self.emit.write(&belief.name.name);
339                self.emit.write(": ");
340                self.emit_type(&belief.ty);
341                self.emit.writeln(",");
342            }
343            self.emit.dedent();
344            self.emit.writeln("}");
345        }
346        self.emit.blank_line();
347
348        // Find the output type from the start handler
349        let output_type = self.infer_agent_output_type(agent);
350
351        // Impl block
352        self.emit.write("impl ");
353        self.emit.write(name);
354        self.emit.writeln(" {");
355        self.emit.indent();
356
357        // Generate handlers
358        for handler in &agent.handlers {
359            match &handler.event {
360                EventKind::Start => {
361                    self.emit
362                        .write("async fn on_start(self, ctx: &mut AgentContext<");
363                    self.emit.write(&output_type);
364                    self.emit.write(">) -> SageResult<");
365                    self.emit.write(&output_type);
366                    self.emit.writeln("> {");
367                    self.emit.indent();
368                    self.generate_block_contents(&handler.body);
369                    self.emit.dedent();
370                    self.emit.writeln("}");
371                }
372
373                // RFC-0007: Generate on_error handler
374                EventKind::Error { param_name } => {
375                    self.emit.write("async fn on_error(self, ");
376                    self.emit.write(&param_name.name);
377                    self.emit.write(": SageError, ctx: &mut AgentContext<");
378                    self.emit.write(&output_type);
379                    self.emit.write(">) -> SageResult<");
380                    self.emit.write(&output_type);
381                    self.emit.writeln("> {");
382                    self.emit.indent();
383                    self.generate_block_contents(&handler.body);
384                    self.emit.dedent();
385                    self.emit.writeln("}");
386                }
387
388                // Other handlers (message, stop) - future work
389                _ => {}
390            }
391        }
392
393        self.emit.dedent();
394        self.emit.writeln("}");
395    }
396
397    fn generate_main(&mut self, agent: &AgentDecl) {
398        let entry_agent = &agent.name.name;
399        let has_error_handler = agent
400            .handlers
401            .iter()
402            .any(|h| matches!(h.event, EventKind::Error { .. }));
403
404        // RFC-0011: Check if agent uses tools
405        let has_tools = !agent.tool_uses.is_empty();
406
407        self.emit.writeln("#[tokio::main]");
408        self.emit
409            .writeln("async fn main() -> Result<(), Box<dyn std::error::Error>> {");
410        self.emit.indent();
411
412        // Helper to generate agent construction (with or without tool fields)
413        let agent_construct = if has_tools {
414            let mut s = format!("{entry_agent} {{ ");
415            for (i, tool_use) in agent.tool_uses.iter().enumerate() {
416                if i > 0 {
417                    s.push_str(", ");
418                }
419                // Generate: http: HttpClient::from_env()
420                s.push_str(&tool_use.name.to_lowercase());
421                s.push_str(": ");
422                s.push_str(&tool_use.name);
423                s.push_str("Client::from_env()");
424            }
425            s.push_str(" }");
426            s
427        } else {
428            entry_agent.to_string()
429        };
430
431        if has_error_handler {
432            // RFC-0007: Generate error dispatch code
433            // Handlers take &mut ctx, so no cloning needed
434            self.emit
435                .writeln("let handle = sage_runtime::spawn(|mut ctx| async move {");
436            self.emit.indent();
437            self.emit.write("let agent = ");
438            self.emit.write(&agent_construct);
439            self.emit.writeln(";");
440            self.emit.writeln("match agent.on_start(&mut ctx).await {");
441            self.emit.indent();
442            self.emit.writeln("Ok(result) => Ok(result),");
443            self.emit.write("Err(e) => ");
444            self.emit.write(&agent_construct);
445            self.emit.writeln(".on_error(e, &mut ctx).await,");
446            self.emit.dedent();
447            self.emit.writeln("}");
448            self.emit.dedent();
449            self.emit.writeln("});");
450        } else {
451            // Simple case: no error handler
452            self.emit.write("let handle = sage_runtime::spawn(|mut ctx| ");
453            self.emit.write(&agent_construct);
454            self.emit.writeln(".on_start(&mut ctx));");
455        }
456
457        self.emit.writeln("let result = handle.result().await?;");
458        self.emit.writeln("println!(\"{:?}\", result);");
459        self.emit.writeln("Ok(())");
460
461        self.emit.dedent();
462        self.emit.writeln("}");
463    }
464
465    fn generate_block(&mut self, block: &Block) {
466        self.emit.open_brace();
467        self.generate_block_contents(block);
468        self.emit.close_brace();
469    }
470
471    fn generate_block_inline(&mut self, block: &Block) {
472        self.emit.open_brace();
473        self.generate_block_contents(block);
474        self.emit.close_brace_inline();
475    }
476
477    fn generate_block_contents(&mut self, block: &Block) {
478        for stmt in &block.stmts {
479            self.generate_stmt(stmt);
480        }
481    }
482
483    fn generate_stmt(&mut self, stmt: &Stmt) {
484        match stmt {
485            Stmt::Let {
486                name, ty, value, ..
487            } => {
488                self.emit.write("let ");
489                if ty.is_some() {
490                    self.emit.write(&name.name);
491                    self.emit.write(": ");
492                    self.emit_type(ty.as_ref().unwrap());
493                } else {
494                    self.emit.write(&name.name);
495                }
496                self.emit.write(" = ");
497                self.generate_expr(value);
498                self.emit.writeln(";");
499            }
500
501            Stmt::Assign { name, value, .. } => {
502                self.emit.write(&name.name);
503                self.emit.write(" = ");
504                self.generate_expr(value);
505                self.emit.writeln(";");
506            }
507
508            Stmt::Return { value, .. } => {
509                self.emit.write("return ");
510                if let Some(expr) = value {
511                    self.generate_expr(expr);
512                }
513                self.emit.writeln(";");
514            }
515
516            Stmt::If {
517                condition,
518                then_block,
519                else_block,
520                ..
521            } => {
522                self.emit.write("if ");
523                self.generate_expr(condition);
524                self.emit.write(" ");
525                if else_block.is_some() {
526                    self.generate_block_inline(then_block);
527                    self.emit.write(" else ");
528                    match else_block.as_ref().unwrap() {
529                        sage_parser::ElseBranch::Block(block) => {
530                            self.generate_block(block);
531                        }
532                        sage_parser::ElseBranch::ElseIf(stmt) => {
533                            self.generate_stmt(stmt);
534                        }
535                    }
536                } else {
537                    self.generate_block(then_block);
538                }
539            }
540
541            Stmt::For {
542                pattern,
543                iter,
544                body,
545                ..
546            } => {
547                self.emit.write("for ");
548                self.emit_pattern(pattern);
549                self.emit.write(" in ");
550                self.generate_expr(iter);
551                self.emit.write(" ");
552                self.generate_block(body);
553            }
554
555            Stmt::While {
556                condition, body, ..
557            } => {
558                self.emit.write("while ");
559                self.generate_expr(condition);
560                self.emit.write(" ");
561                self.generate_block(body);
562            }
563
564            Stmt::Loop { body, .. } => {
565                self.emit.write("loop ");
566                self.generate_block(body);
567            }
568
569            Stmt::Break { .. } => {
570                self.emit.writeln("break;");
571            }
572
573            Stmt::Expr { expr, .. } => {
574                // Handle emit specially
575                if let Expr::Emit { value, .. } = expr {
576                    self.emit.write("return ctx.emit(");
577                    self.generate_expr(value);
578                    self.emit.writeln(");");
579                } else {
580                    self.generate_expr(expr);
581                    self.emit.writeln(";");
582                }
583            }
584
585            Stmt::LetTuple { names, value, .. } => {
586                self.emit.write("let (");
587                for (i, name) in names.iter().enumerate() {
588                    if i > 0 {
589                        self.emit.write(", ");
590                    }
591                    self.emit.write(&name.name);
592                }
593                self.emit.write(") = ");
594                self.generate_expr(value);
595                self.emit.writeln(";");
596            }
597        }
598    }
599
600    fn generate_expr(&mut self, expr: &Expr) {
601        match expr {
602            Expr::Literal { value, .. } => {
603                self.emit_literal(value);
604            }
605
606            Expr::Var { name, .. } => {
607                self.emit.write(&name.name);
608            }
609
610            Expr::Binary {
611                op, left, right, ..
612            } => {
613                // Handle string concatenation specially
614                if matches!(op, BinOp::Concat) {
615                    self.emit.write("format!(\"{}{}\", ");
616                    self.generate_expr(left);
617                    self.emit.write(", ");
618                    self.generate_expr(right);
619                    self.emit.write(")");
620                } else {
621                    self.emit.write("(");
622                    self.generate_expr(left);
623                    self.emit.write(" ");
624                    self.emit_binop(op);
625                    self.emit.write(" ");
626                    self.generate_expr(right);
627                    self.emit.write(")");
628                }
629            }
630
631            Expr::Unary { op, operand, .. } => {
632                self.emit_unaryop(op);
633                self.generate_expr(operand);
634            }
635
636            Expr::Call { name, args, .. } => {
637                let fn_name = &name.name;
638
639                // Handle builtins
640                match fn_name.as_str() {
641                    "print" => {
642                        self.emit.write("println!(\"{}\", ");
643                        self.generate_expr(&args[0]);
644                        self.emit.write(")");
645                    }
646                    "str" => {
647                        self.generate_expr(&args[0]);
648                        self.emit.write(".to_string()");
649                    }
650                    "len" => {
651                        self.generate_expr(&args[0]);
652                        self.emit.write(".len() as i64");
653                    }
654                    _ => {
655                        self.emit.write(fn_name);
656                        self.emit.write("(");
657                        for (i, arg) in args.iter().enumerate() {
658                            if i > 0 {
659                                self.emit.write(", ");
660                            }
661                            self.generate_expr(arg);
662                        }
663                        self.emit.write(")");
664                    }
665                }
666            }
667
668            Expr::SelfField { field, .. } => {
669                self.emit.write("self.");
670                self.emit.write(&field.name);
671            }
672
673            Expr::SelfMethodCall { method, args, .. } => {
674                self.emit.write("self.");
675                self.emit.write(&method.name);
676                self.emit.write("(");
677                for (i, arg) in args.iter().enumerate() {
678                    if i > 0 {
679                        self.emit.write(", ");
680                    }
681                    self.generate_expr(arg);
682                }
683                self.emit.write(")");
684            }
685
686            Expr::List { elements, .. } => {
687                self.emit.write("vec![");
688                for (i, elem) in elements.iter().enumerate() {
689                    if i > 0 {
690                        self.emit.write(", ");
691                    }
692                    self.generate_expr(elem);
693                }
694                self.emit.write("]");
695            }
696
697            Expr::Paren { inner, .. } => {
698                self.emit.write("(");
699                self.generate_expr(inner);
700                self.emit.write(")");
701            }
702
703            Expr::Infer { template, .. } => {
704                self.emit.write("ctx.infer_string(&");
705                self.emit_string_template(template);
706                self.emit.write(").await?");
707            }
708
709            Expr::Spawn { agent, fields, .. } => {
710                self.emit.write("sage_runtime::spawn(|ctx| ");
711                self.emit.write(&agent.name);
712                if fields.is_empty() {
713                    self.emit.write(".on_start(ctx))");
714                } else {
715                    self.emit.write(" { ");
716                    for (i, field) in fields.iter().enumerate() {
717                        if i > 0 {
718                            self.emit.write(", ");
719                        }
720                        self.emit.write(&field.name.name);
721                        self.emit.write(": ");
722                        self.generate_expr(&field.value);
723                    }
724                    self.emit.write(" }.on_start(ctx))");
725                }
726            }
727
728            Expr::Await { handle, .. } => {
729                self.generate_expr(handle);
730                self.emit.write(".result().await?");
731            }
732
733            Expr::Send {
734                handle, message, ..
735            } => {
736                self.generate_expr(handle);
737                self.emit.write(".send(sage_runtime::Message::new(");
738                self.generate_expr(message);
739                self.emit.write(")?).await?");
740            }
741
742            Expr::Emit { value, .. } => {
743                self.emit.write("ctx.emit(");
744                self.generate_expr(value);
745                self.emit.write(")");
746            }
747
748            Expr::StringInterp { template, .. } => {
749                self.emit_string_template(template);
750            }
751
752            Expr::Match {
753                scrutinee, arms, ..
754            } => {
755                self.emit.write("match ");
756                self.generate_expr(scrutinee);
757                self.emit.writeln(" {");
758                self.emit.indent();
759                for arm in arms {
760                    self.emit_pattern(&arm.pattern);
761                    self.emit.write(" => ");
762                    self.generate_expr(&arm.body);
763                    self.emit.writeln(",");
764                }
765                self.emit.dedent();
766                self.emit.write("}");
767            }
768
769            Expr::RecordConstruct { name, fields, .. } => {
770                self.emit.write(&name.name);
771                self.emit.write(" { ");
772                for (i, field) in fields.iter().enumerate() {
773                    if i > 0 {
774                        self.emit.write(", ");
775                    }
776                    self.emit.write(&field.name.name);
777                    self.emit.write(": ");
778                    self.generate_expr(&field.value);
779                }
780                self.emit.write(" }");
781            }
782
783            Expr::FieldAccess { object, field, .. } => {
784                self.generate_expr(object);
785                self.emit.write(".");
786                self.emit.write(&field.name);
787            }
788
789            Expr::Receive { .. } => {
790                self.emit.write("ctx.receive().await?");
791            }
792
793            // RFC-0007: Error handling
794            Expr::Try { expr, .. } => {
795                // Generate the inner expression with ? for error propagation
796                self.generate_expr(expr);
797                self.emit.write("?");
798            }
799
800            Expr::Catch {
801                expr,
802                error_bind,
803                recovery,
804                ..
805            } => {
806                // Generate a match expression to handle the Result
807                self.emit.write("match ");
808                self.generate_expr(expr);
809                self.emit.writeln(" {");
810                self.emit.indent();
811
812                // Ok arm - unwrap the value
813                self.emit.writeln("Ok(__val) => __val,");
814
815                // Err arm - run recovery
816                if let Some(err_name) = error_bind {
817                    self.emit.write("Err(");
818                    self.emit.write(&err_name.name);
819                    self.emit.write(") => ");
820                } else {
821                    self.emit.write("Err(_) => ");
822                }
823                self.generate_expr(recovery);
824                self.emit.writeln(",");
825
826                self.emit.dedent();
827                self.emit.write("}");
828            }
829
830            // RFC-0009: Closures
831            Expr::Closure { params, body, .. } => {
832                // Generate: Box::new(move |param1: Type1, param2: Type2| { body })
833                self.emit.write("Box::new(move |");
834                for (i, param) in params.iter().enumerate() {
835                    if i > 0 {
836                        self.emit.write(", ");
837                    }
838                    self.emit.write(&param.name.name);
839                    if let Some(ty) = &param.ty {
840                        self.emit.write(": ");
841                        self.emit_type(ty);
842                    }
843                }
844                self.emit.write("| ");
845                self.generate_expr(body);
846                self.emit.write(")");
847            }
848
849            // RFC-0010: Tuples and Maps
850            Expr::Tuple { elements, .. } => {
851                self.emit.write("(");
852                for (i, elem) in elements.iter().enumerate() {
853                    if i > 0 {
854                        self.emit.write(", ");
855                    }
856                    self.generate_expr(elem);
857                }
858                self.emit.write(")");
859            }
860
861            Expr::TupleIndex { tuple, index, .. } => {
862                self.generate_expr(tuple);
863                self.emit.write(&format!(".{index}"));
864            }
865
866            Expr::Map { entries, .. } => {
867                if entries.is_empty() {
868                    self.emit.write("std::collections::HashMap::new()");
869                } else {
870                    self.emit.write("std::collections::HashMap::from([");
871                    for (i, entry) in entries.iter().enumerate() {
872                        if i > 0 {
873                            self.emit.write(", ");
874                        }
875                        self.emit.write("(");
876                        self.generate_expr(&entry.key);
877                        self.emit.write(", ");
878                        self.generate_expr(&entry.value);
879                        self.emit.write(")");
880                    }
881                    self.emit.write("])");
882                }
883            }
884
885            Expr::VariantConstruct {
886                enum_name,
887                variant,
888                payload,
889                ..
890            } => {
891                self.emit.write(&enum_name.name);
892                self.emit.write("::");
893                self.emit.write(&variant.name);
894                if let Some(payload_expr) = payload {
895                    self.emit.write("(");
896                    self.generate_expr(payload_expr);
897                    self.emit.write(")");
898                }
899            }
900
901            // RFC-0011: Tool calls
902            Expr::ToolCall {
903                tool,
904                function,
905                args,
906                ..
907            } => {
908                // Generate: self.tool_name.function(args).await
909                // Returns SageResult<T> - must be handled with try/catch
910                self.emit.write("self.");
911                self.emit.write(&tool.name.to_lowercase());
912                self.emit.write(".");
913                self.emit.write(&function.name);
914                self.emit.write("(");
915                for (i, arg) in args.iter().enumerate() {
916                    if i > 0 {
917                        self.emit.write(", ");
918                    }
919                    self.generate_expr(arg);
920                }
921                self.emit.write(").await");
922            }
923        }
924    }
925
926    fn emit_pattern(&mut self, pattern: &sage_parser::Pattern) {
927        use sage_parser::Pattern;
928        match pattern {
929            Pattern::Wildcard { .. } => {
930                self.emit.write("_");
931            }
932            Pattern::Variant {
933                enum_name,
934                variant,
935                payload,
936                ..
937            } => {
938                if let Some(enum_name) = enum_name {
939                    self.emit.write(&enum_name.name);
940                    self.emit.write("::");
941                }
942                self.emit.write(&variant.name);
943                if let Some(inner_pattern) = payload {
944                    self.emit.write("(");
945                    self.emit_pattern(inner_pattern);
946                    self.emit.write(")");
947                }
948            }
949            Pattern::Literal { value, .. } => {
950                self.emit_literal(value);
951            }
952            Pattern::Binding { name, .. } => {
953                self.emit.write(&name.name);
954            }
955            Pattern::Tuple { elements, .. } => {
956                self.emit.write("(");
957                for (i, elem) in elements.iter().enumerate() {
958                    if i > 0 {
959                        self.emit.write(", ");
960                    }
961                    self.emit_pattern(elem);
962                }
963                self.emit.write(")");
964            }
965        }
966    }
967
968    fn emit_literal(&mut self, lit: &Literal) {
969        match lit {
970            Literal::Int(n) => {
971                self.emit.write(&format!("{n}_i64"));
972            }
973            Literal::Float(f) => {
974                self.emit.write(&format!("{f}_f64"));
975            }
976            Literal::Bool(b) => {
977                self.emit.write(if *b { "true" } else { "false" });
978            }
979            Literal::String(s) => {
980                // Escape the string for Rust
981                self.emit.write("\"");
982                for c in s.chars() {
983                    match c {
984                        '"' => self.emit.write_raw("\\\""),
985                        '\\' => self.emit.write_raw("\\\\"),
986                        '\n' => self.emit.write_raw("\\n"),
987                        '\r' => self.emit.write_raw("\\r"),
988                        '\t' => self.emit.write_raw("\\t"),
989                        _ => self.emit.write_raw(&c.to_string()),
990                    }
991                }
992                self.emit.write("\".to_string()");
993            }
994        }
995    }
996
997    fn emit_string_template(&mut self, template: &sage_parser::StringTemplate) {
998        if !template.has_interpolations() {
999            // Simple string literal
1000            if let Some(StringPart::Literal(s)) = template.parts.first() {
1001                self.emit.write("\"");
1002                self.emit.write_raw(s);
1003                self.emit.write("\".to_string()");
1004            }
1005            return;
1006        }
1007
1008        // Build format string and args
1009        self.emit.write("format!(\"");
1010        for part in &template.parts {
1011            match part {
1012                StringPart::Literal(s) => {
1013                    // Escape braces for format string
1014                    let escaped = s.replace('{', "{{").replace('}', "}}");
1015                    self.emit.write_raw(&escaped);
1016                }
1017                StringPart::Interpolation(_) => {
1018                    self.emit.write_raw("{}");
1019                }
1020            }
1021        }
1022        self.emit.write("\"");
1023
1024        // Add the interpolation args
1025        for part in &template.parts {
1026            if let StringPart::Interpolation(ident) = part {
1027                self.emit.write(", ");
1028                self.emit.write(&ident.name);
1029            }
1030        }
1031        self.emit.write(")");
1032    }
1033
1034    fn emit_type(&mut self, ty: &TypeExpr) {
1035        match ty {
1036            TypeExpr::Int => self.emit.write("i64"),
1037            TypeExpr::Float => self.emit.write("f64"),
1038            TypeExpr::Bool => self.emit.write("bool"),
1039            TypeExpr::String => self.emit.write("String"),
1040            TypeExpr::Unit => self.emit.write("()"),
1041            TypeExpr::List(inner) => {
1042                self.emit.write("Vec<");
1043                self.emit_type(inner);
1044                self.emit.write(">");
1045            }
1046            TypeExpr::Option(inner) => {
1047                self.emit.write("Option<");
1048                self.emit_type(inner);
1049                self.emit.write(">");
1050            }
1051            TypeExpr::Inferred(inner) => {
1052                // Inferred<T> just becomes T at runtime
1053                self.emit_type(inner);
1054            }
1055            TypeExpr::Agent(agent_name) => {
1056                // Agent handles use the agent's output type, but we don't know it here
1057                // For now, just use a generic output type
1058                self.emit.write("AgentHandle<");
1059                self.emit.write(&agent_name.name);
1060                self.emit.write("Output>");
1061            }
1062            TypeExpr::Named(name) => {
1063                self.emit.write(&name.name);
1064            }
1065
1066            // RFC-0007: Error handling
1067            TypeExpr::Error => {
1068                self.emit.write("sage_runtime::SageError");
1069            }
1070
1071            // RFC-0009: Function types
1072            TypeExpr::Fn(params, ret) => {
1073                self.emit.write("Box<dyn Fn(");
1074                for (i, param) in params.iter().enumerate() {
1075                    if i > 0 {
1076                        self.emit.write(", ");
1077                    }
1078                    self.emit_type(param);
1079                }
1080                self.emit.write(") -> ");
1081                self.emit_type(ret);
1082                self.emit.write(" + Send + 'static>");
1083            }
1084
1085            // RFC-0010: Maps, tuples, Result
1086            TypeExpr::Map(key, value) => {
1087                self.emit.write("std::collections::HashMap<");
1088                self.emit_type(key);
1089                self.emit.write(", ");
1090                self.emit_type(value);
1091                self.emit.write(">");
1092            }
1093            TypeExpr::Tuple(elems) => {
1094                self.emit.write("(");
1095                for (i, elem) in elems.iter().enumerate() {
1096                    if i > 0 {
1097                        self.emit.write(", ");
1098                    }
1099                    self.emit_type(elem);
1100                }
1101                self.emit.write(")");
1102            }
1103            TypeExpr::Result(ok, err) => {
1104                self.emit.write("Result<");
1105                self.emit_type(ok);
1106                self.emit.write(", ");
1107                self.emit_type(err);
1108                self.emit.write(">");
1109            }
1110        }
1111    }
1112
1113    fn emit_binop(&mut self, op: &BinOp) {
1114        let s = match op {
1115            BinOp::Add => "+",
1116            BinOp::Sub => "-",
1117            BinOp::Mul => "*",
1118            BinOp::Div => "/",
1119            BinOp::Eq => "==",
1120            BinOp::Ne => "!=",
1121            BinOp::Lt => "<",
1122            BinOp::Gt => ">",
1123            BinOp::Le => "<=",
1124            BinOp::Ge => ">=",
1125            BinOp::And => "&&",
1126            BinOp::Or => "||",
1127            BinOp::Concat => "++", // Handled specially above
1128        };
1129        self.emit.write(s);
1130    }
1131
1132    fn emit_unaryop(&mut self, op: &UnaryOp) {
1133        let s = match op {
1134            UnaryOp::Neg => "-",
1135            UnaryOp::Not => "!",
1136        };
1137        self.emit.write(s);
1138    }
1139
1140    fn infer_agent_output_type(&self, agent: &AgentDecl) -> String {
1141        // Look for emit expression in start handler to infer return type
1142        // For now, default to i64
1143        for handler in &agent.handlers {
1144            if let EventKind::Start = &handler.event {
1145                if let Some(ty) = self.find_emit_type(&handler.body) {
1146                    return ty;
1147                }
1148            }
1149        }
1150        "i64".to_string()
1151    }
1152
1153    fn find_emit_type(&self, block: &Block) -> Option<String> {
1154        for stmt in &block.stmts {
1155            if let Stmt::Expr { expr, .. } = stmt {
1156                if let Expr::Emit { value, .. } = expr {
1157                    return Some(self.infer_expr_type(value));
1158                }
1159            }
1160            // Check nested blocks
1161            if let Stmt::If {
1162                then_block,
1163                else_block,
1164                ..
1165            } = stmt
1166            {
1167                if let Some(ty) = self.find_emit_type(then_block) {
1168                    return Some(ty);
1169                }
1170                if let Some(else_branch) = else_block {
1171                    if let sage_parser::ElseBranch::Block(block) = else_branch {
1172                        if let Some(ty) = self.find_emit_type(block) {
1173                            return Some(ty);
1174                        }
1175                    }
1176                }
1177            }
1178        }
1179        None
1180    }
1181
1182    fn infer_expr_type(&self, expr: &Expr) -> String {
1183        match expr {
1184            Expr::Literal { value, .. } => match value {
1185                Literal::Int(_) => "i64".to_string(),
1186                Literal::Float(_) => "f64".to_string(),
1187                Literal::Bool(_) => "bool".to_string(),
1188                Literal::String(_) => "String".to_string(),
1189            },
1190            Expr::Var { .. } => "i64".to_string(), // Conservative default
1191            Expr::Binary { op, .. } => {
1192                if matches!(
1193                    op,
1194                    BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge
1195                ) {
1196                    "bool".to_string()
1197                } else if matches!(op, BinOp::Concat) {
1198                    "String".to_string()
1199                } else {
1200                    "i64".to_string()
1201                }
1202            }
1203            Expr::Infer { .. } | Expr::StringInterp { .. } => "String".to_string(),
1204            Expr::Call { name, .. } if name.name == "str" => "String".to_string(),
1205            Expr::Call { name, .. } if name.name == "len" => "i64".to_string(),
1206            _ => "i64".to_string(),
1207        }
1208    }
1209}
1210
// Unit tests for the generator. Each test turns a small Sage program into Rust
// source text and checks it for expected substrings. The generated Rust is not
// compiled or executed by these tests.
#[cfg(test)]
mod tests {
    use super::*;
    use sage_lexer::lex;
    use sage_parser::parse;
    use std::sync::Arc;

    /// Lex, parse and generate `main.rs` for `source`.
    ///
    /// Panics if lexing fails, if the parser reports any error, or if no
    /// program is returned, so malformed fixtures fail the test immediately.
    fn generate_source(source: &str) -> String {
        let lex_result = lex(source).expect("lexing failed");
        let source_arc: Arc<str> = Arc::from(source);
        let (program, errors) = parse(lex_result.tokens(), source_arc);
        assert!(errors.is_empty(), "parse errors: {errors:?}");
        let program = program.expect("should parse");
        generate(&program, "test").main_rs
    }

    /// Smallest runnable program. Checks for a unit-struct agent, an
    /// `on_start` method, an `_i64`-suffixed emit and a tokio entry point.
    #[test]
    fn generate_minimal_program() {
        let source = r#"
            agent Main {
                on start {
                    emit(42);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("struct Main;"));
        assert!(output.contains("async fn on_start"));
        assert!(output.contains("ctx.emit(42_i64)"));
        assert!(output.contains("#[tokio::main]"));
    }

    /// Free functions map `Int` to `i64`, and binary expressions are emitted
    /// wrapped in parentheses.
    #[test]
    fn generate_function() {
        let source = r#"
            fn add(a: Int, b: Int) -> Int {
                return a + b;
            }
            agent Main {
                on start {
                    emit(add(1, 2));
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("fn add(a: i64, b: i64) -> i64"));
        assert!(output.contains("return (a + b);"));
    }

    /// An agent with declared fields becomes a struct with named fields,
    /// which are read through `self.`.
    #[test]
    fn generate_agent_with_beliefs() {
        let source = r#"
            agent Worker {
                value: Int

                on start {
                    emit(self.value * 2);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("struct Worker {"));
        assert!(output.contains("value: i64,"));
        assert!(output.contains("self.value"));
    }

    /// `{name}` interpolation becomes a `format!` call with positional `{}`.
    #[test]
    fn generate_string_interpolation() {
        let source = r#"
            agent Main {
                on start {
                    let name = "World";
                    let msg = "Hello, {name}!";
                    print(msg);
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("format!(\"Hello, {}!\", name)"));
    }

    /// `if` conditions are parenthesised. The `else` branch is only checked
    /// for presence, not for its exact layout.
    #[test]
    fn generate_control_flow() {
        let source = r#"
            agent Main {
                on start {
                    let x = 10;
                    if x > 5 {
                        emit(1);
                    } else {
                        emit(0);
                    }
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("if (x > 5_i64)"), "output:\n{output}");
        // else is on the same line after close brace
        assert!(output.contains("else"), "output:\n{output}");
    }

    /// List literals become `vec![...]` in `for` loops; `while` conditions
    /// are parenthesised.
    #[test]
    fn generate_loops() {
        let source = r#"
            agent Main {
                on start {
                    for x in [1, 2, 3] {
                        print(str(x));
                    }
                    let n = 0;
                    while n < 5 {
                        n = n + 1;
                    }
                    emit(n);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("for x in vec![1_i64, 2_i64, 3_i64]"));
        assert!(output.contains("while (n < 5_i64)"));
    }

    /// `pub fn` keeps its `pub` modifier.
    #[test]
    fn generate_pub_function() {
        let source = r#"
            pub fn helper(x: Int) -> Int {
                return x * 2;
            }
            agent Main {
                on start {
                    emit(helper(21));
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("pub fn helper(x: i64) -> i64"));
    }

    /// `pub agent` becomes a `pub struct`.
    #[test]
    fn generate_pub_agent() {
        let source = r#"
            pub agent Worker {
                on start {
                    emit(42);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("pub struct Worker;"));
    }

    /// Goes through the loader, writing a real file in a temporary directory,
    /// and the module-tree entry point instead of `generate`.
    #[test]
    fn generate_module_tree_simple() {
        use sage_loader::load_single_file;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test.sg");
        fs::write(
            &file,
            r#"
agent Main {
    on start {
        emit(42);
    }
}
run Main;
"#,
        )
        .unwrap();

        let tree = load_single_file(&file).unwrap();
        let project = generate_module_tree(&tree, "test");

        assert!(project.main_rs.contains("struct Main;"));
        assert!(project.main_rs.contains("async fn on_start"));
        assert!(project.main_rs.contains("#[tokio::main]"));
    }

    /// Records become `#[derive(Debug, Clone)]` structs. Record literals and
    /// field access map one-to-one.
    #[test]
    fn generate_record_declaration() {
        let source = r#"
            record Point {
                x: Int,
                y: Int,
            }
            agent Main {
                on start {
                    let p = Point { x: 10, y: 20 };
                    emit(p.x);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("#[derive(Debug, Clone)]"));
        assert!(output.contains("struct Point {"));
        assert!(output.contains("x: i64,"));
        assert!(output.contains("y: i64,"));
        assert!(output.contains("Point { x: 10_i64, y: 20_i64 }"));
        assert!(output.contains("p.x"));
    }

    /// Payload-free enums derive `Copy` and `Eq` in addition to `Debug` and
    /// `Clone`.
    #[test]
    fn generate_enum_declaration() {
        let source = r#"
            enum Status {
                Active,
                Inactive,
                Pending,
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("#[derive(Debug, Clone, Copy, PartialEq, Eq)]"));
        assert!(output.contains("enum Status {"));
        assert!(output.contains("Active,"));
        assert!(output.contains("Inactive,"));
        assert!(output.contains("Pending,"));
    }

    /// Constants keep their declared type and literal initialiser.
    ///
    /// NOTE(review): the `String` expectation pins current output that is not
    /// valid Rust, because `to_string` is not a `const fn` and cannot appear
    /// in a `const` initialiser.
    #[test]
    fn generate_const_declaration() {
        let source = r#"
            const MAX_SIZE: Int = 100;
            const GREETING: String = "Hello";
            agent Main {
                on start {
                    emit(MAX_SIZE);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("const MAX_SIZE: i64 = 100_i64;"));
        assert!(output.contains("const GREETING: String = \"Hello\".to_string();"));
    }

    /// `match` arms written without an enum prefix are emitted as bare
    /// variant names.
    #[test]
    fn generate_match_expression() {
        let source = r#"
            enum Status {
                Active,
                Inactive,
            }
            fn check_status(s: Status) -> Int {
                return match s {
                    Active => 1,
                    Inactive => 0,
                };
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("match s {"));
        assert!(output.contains("Active => 1_i64,"));
        assert!(output.contains("Inactive => 0_i64,"));
    }

    // =========================================================================
    // RFC-0007: Error handling codegen tests
    // =========================================================================

    /// A `fails` function returns `SageResult<T>`.
    #[test]
    fn generate_fallible_function() {
        let source = r#"
            fn get_data(url: String) -> String fails {
                return url;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // Fallible function should return SageResult<T>
        assert!(output.contains("fn get_data(url: String) -> SageResult<String>"));
    }

    /// `try expr` becomes the `?` operator.
    #[test]
    fn generate_try_expression() {
        let source = r#"
            fn fallible() -> Int fails { return 42; }
            fn caller() -> Int fails {
                let x = try fallible();
                return x;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // try should generate ? operator
        assert!(output.contains("fallible()?"));
    }

    /// `expr catch { fallback }` becomes a `match` over `Ok` / `Err(_)`.
    #[test]
    fn generate_catch_expression() {
        let source = r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch { 0 };
                    emit(x);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // catch should generate match expression
        assert!(output.contains("match fallible()"));
        assert!(output.contains("Ok(__val) => __val"));
        assert!(output.contains("Err(_) => 0_i64"));
    }

    /// `catch(e)` binds the error to `e` in the `Err` arm.
    #[test]
    fn generate_catch_with_binding() {
        let source = r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch(e) { 0 };
                    emit(x);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // catch with binding should capture the error
        assert!(output.contains("Err(e) => 0_i64"));
    }

    /// `on error(e)` becomes an `on_error` method. The generated entry point
    /// must also dispatch to it.
    #[test]
    fn generate_on_error_handler() {
        let source = r#"
            agent Main {
                on start {
                    emit(0);
                }
                on error(e) {
                    emit(1);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // Should generate on_error method with &mut ctx
        assert!(output.contains("async fn on_error(self, e: SageError, ctx: &mut AgentContext"));
        // Main should dispatch to on_error on failure with &mut ctx
        assert!(output.contains(".on_error(e, &mut ctx)"));
    }

    // =========================================================================
    // RFC-0011: Tool support codegen tests
    // =========================================================================

    /// `use Http` adds a lowercase `http: HttpClient` field initialised from
    /// the environment; tool calls go through that field.
    #[test]
    fn generate_agent_with_tool_use() {
        let source = r#"
            agent Fetcher {
                use Http

                on start {
                    let r = Http.get("https://example.com");
                    emit(0);
                }
            }
            run Fetcher;
        "#;

        let output = generate_source(source);
        // Should generate struct with http field
        assert!(output.contains("struct Fetcher {"));
        assert!(output.contains("http: HttpClient,"));
        // Should initialize HttpClient in main
        assert!(output.contains("http: HttpClient::from_env()"));
        // Should generate tool call
        assert!(output.contains("self.http.get("));
    }

    /// Pins the exact tool-call expression: string arguments become owned
    /// `String`s and the call is awaited without `?`.
    #[test]
    fn generate_tool_call_expression() {
        let source = r#"
            agent Fetcher {
                use Http

                on start {
                    let response = Http.get("https://httpbin.org/get");
                    emit(0);
                }
            }
            run Fetcher;
        "#;

        let output = generate_source(source);
        // Tool call should generate self.http.get(...).await (no ?, handled by try/catch)
        assert!(output.contains("self.http.get(\"https://httpbin.org/get\".to_string()).await"));
    }
}
1657}