Skip to main content

sage_codegen/
generator.rs

//! Main code generator: lowers a parsed Sage [`Program`] into the `main.rs` and `Cargo.toml` of a standalone Rust crate.
2
3use crate::emit::Emitter;
4use sage_parser::{
5    AgentDecl, BinOp, Block, EventKind, Expr, FnDecl, Literal, Program, Stmt, StringPart,
6    UnaryOp,
7};
8use sage_types::TypeExpr;
9
/// Generated Rust project files.
///
/// Produced by [`generate`], which only builds these strings; it performs no
/// file I/O, so writing them into a crate directory is up to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedProject {
    /// The main.rs content.
    pub main_rs: String,
    /// The Cargo.toml content.
    pub cargo_toml: String,
}
17
18/// Generate Rust code from a Sage program.
19pub fn generate(program: &Program, project_name: &str) -> GeneratedProject {
20    let mut gen = Generator::new();
21    let main_rs = gen.generate_program(program);
22    let cargo_toml = gen.generate_cargo_toml(project_name);
23    GeneratedProject { main_rs, cargo_toml }
24}
25
26struct Generator {
27    emit: Emitter,
28}
29
30impl Generator {
31    fn new() -> Self {
32        Self {
33            emit: Emitter::new(),
34        }
35    }
36
37    fn generate_program(&mut self, program: &Program) -> String {
38        // Prelude
39        self.emit.writeln("//! Generated by Sage compiler. Do not edit.");
40        self.emit.blank_line();
41        self.emit.writeln("use sage_runtime::prelude::*;");
42        self.emit.blank_line();
43
44        // Functions
45        for func in &program.functions {
46            self.generate_function(func);
47            self.emit.blank_line();
48        }
49
50        // Agents
51        for agent in &program.agents {
52            self.generate_agent(agent);
53            self.emit.blank_line();
54        }
55
56        // Entry point
57        self.generate_main(&program.run_agent.name);
58
59        std::mem::take(&mut self.emit).finish()
60    }
61
62    fn generate_cargo_toml(&self, name: &str) -> String {
63        // Use a relative path that works from target/sage/<project>/
64        // This assumes the standard project layout
65        format!(
66            r#"[package]
67name = "{name}"
68version = "0.1.0"
69edition = "2021"
70
71[dependencies]
72sage-runtime = {{ path = "../../../crates/sage-runtime" }}
73tokio = {{ version = "1", features = ["full"] }}
74serde = {{ version = "1", features = ["derive"] }}
75serde_json = "1"
76
77# Standalone project, not part of parent workspace
78[workspace]
79"#
80        )
81    }
82
83    fn generate_function(&mut self, func: &FnDecl) {
84        // Function signature
85        self.emit.write("fn ");
86        self.emit.write(&func.name.name);
87        self.emit.write("(");
88
89        for (i, param) in func.params.iter().enumerate() {
90            if i > 0 {
91                self.emit.write(", ");
92            }
93            self.emit.write(&param.name.name);
94            self.emit.write(": ");
95            self.emit_type(&param.ty);
96        }
97
98        self.emit.write(") -> ");
99        self.emit_type(&func.return_ty);
100        self.emit.write(" ");
101        self.generate_block(&func.body);
102    }
103
104    fn generate_agent(&mut self, agent: &AgentDecl) {
105        let name = &agent.name.name;
106
107        // Struct definition
108        self.emit.write("struct ");
109        self.emit.write(name);
110        if agent.beliefs.is_empty() {
111            self.emit.writeln(";");
112        } else {
113            self.emit.writeln(" {");
114            self.emit.indent();
115            for belief in &agent.beliefs {
116                self.emit.write(&belief.name.name);
117                self.emit.write(": ");
118                self.emit_type(&belief.ty);
119                self.emit.writeln(",");
120            }
121            self.emit.dedent();
122            self.emit.writeln("}");
123        }
124        self.emit.blank_line();
125
126        // Find the output type from the start handler
127        let output_type = self.infer_agent_output_type(agent);
128
129        // Impl block
130        self.emit.write("impl ");
131        self.emit.write(name);
132        self.emit.writeln(" {");
133        self.emit.indent();
134
135        // Generate handlers
136        for handler in &agent.handlers {
137            if let EventKind::Start = &handler.event {
138                self.emit.write("async fn on_start(self, ctx: AgentContext<");
139                self.emit.write(&output_type);
140                self.emit.write(">) -> SageResult<");
141                self.emit.write(&output_type);
142                self.emit.writeln("> {");
143                self.emit.indent();
144                self.generate_block_contents(&handler.body);
145                self.emit.dedent();
146                self.emit.writeln("}");
147            }
148        }
149
150        self.emit.dedent();
151        self.emit.writeln("}");
152    }
153
154    fn generate_main(&mut self, entry_agent: &str) {
155        self.emit.writeln("#[tokio::main]");
156        self.emit.writeln("async fn main() -> Result<(), Box<dyn std::error::Error>> {");
157        self.emit.indent();
158
159        self.emit.write("let handle = sage_runtime::spawn(|ctx| ");
160        self.emit.write(entry_agent);
161        self.emit.writeln(".on_start(ctx));");
162        self.emit.writeln("let result = handle.result().await?;");
163        self.emit.writeln("println!(\"{:?}\", result);");
164        self.emit.writeln("Ok(())");
165
166        self.emit.dedent();
167        self.emit.writeln("}");
168    }
169
170    fn generate_block(&mut self, block: &Block) {
171        self.emit.open_brace();
172        self.generate_block_contents(block);
173        self.emit.close_brace();
174    }
175
176    fn generate_block_inline(&mut self, block: &Block) {
177        self.emit.open_brace();
178        self.generate_block_contents(block);
179        self.emit.close_brace_inline();
180    }
181
182    fn generate_block_contents(&mut self, block: &Block) {
183        for stmt in &block.stmts {
184            self.generate_stmt(stmt);
185        }
186    }
187
188    fn generate_stmt(&mut self, stmt: &Stmt) {
189        match stmt {
190            Stmt::Let { name, ty, value, .. } => {
191                self.emit.write("let ");
192                if ty.is_some() {
193                    self.emit.write(&name.name);
194                    self.emit.write(": ");
195                    self.emit_type(ty.as_ref().unwrap());
196                } else {
197                    self.emit.write(&name.name);
198                }
199                self.emit.write(" = ");
200                self.generate_expr(value);
201                self.emit.writeln(";");
202            }
203
204            Stmt::Assign { name, value, .. } => {
205                self.emit.write(&name.name);
206                self.emit.write(" = ");
207                self.generate_expr(value);
208                self.emit.writeln(";");
209            }
210
211            Stmt::Return { value, .. } => {
212                self.emit.write("return ");
213                if let Some(expr) = value {
214                    self.generate_expr(expr);
215                }
216                self.emit.writeln(";");
217            }
218
219            Stmt::If {
220                condition,
221                then_block,
222                else_block,
223                ..
224            } => {
225                self.emit.write("if ");
226                self.generate_expr(condition);
227                self.emit.write(" ");
228                if else_block.is_some() {
229                    self.generate_block_inline(then_block);
230                    self.emit.write(" else ");
231                    match else_block.as_ref().unwrap() {
232                        sage_parser::ElseBranch::Block(block) => {
233                            self.generate_block(block);
234                        }
235                        sage_parser::ElseBranch::ElseIf(stmt) => {
236                            self.generate_stmt(stmt);
237                        }
238                    }
239                } else {
240                    self.generate_block(then_block);
241                }
242            }
243
244            Stmt::For { var, iter, body, .. } => {
245                self.emit.write("for ");
246                self.emit.write(&var.name);
247                self.emit.write(" in ");
248                self.generate_expr(iter);
249                self.emit.write(" ");
250                self.generate_block(body);
251            }
252
253            Stmt::While {
254                condition, body, ..
255            } => {
256                self.emit.write("while ");
257                self.generate_expr(condition);
258                self.emit.write(" ");
259                self.generate_block(body);
260            }
261
262            Stmt::Expr { expr, .. } => {
263                // Handle emit specially
264                if let Expr::Emit { value, .. } = expr {
265                    self.emit.write("return ctx.emit(");
266                    self.generate_expr(value);
267                    self.emit.writeln(");");
268                } else {
269                    self.generate_expr(expr);
270                    self.emit.writeln(";");
271                }
272            }
273        }
274    }
275
276    fn generate_expr(&mut self, expr: &Expr) {
277        match expr {
278            Expr::Literal { value, .. } => {
279                self.emit_literal(value);
280            }
281
282            Expr::Var { name, .. } => {
283                self.emit.write(&name.name);
284            }
285
286            Expr::Binary {
287                op, left, right, ..
288            } => {
289                // Handle string concatenation specially
290                if matches!(op, BinOp::Concat) {
291                    self.emit.write("format!(\"{}{}\", ");
292                    self.generate_expr(left);
293                    self.emit.write(", ");
294                    self.generate_expr(right);
295                    self.emit.write(")");
296                } else {
297                    self.emit.write("(");
298                    self.generate_expr(left);
299                    self.emit.write(" ");
300                    self.emit_binop(op);
301                    self.emit.write(" ");
302                    self.generate_expr(right);
303                    self.emit.write(")");
304                }
305            }
306
307            Expr::Unary { op, operand, .. } => {
308                self.emit_unaryop(op);
309                self.generate_expr(operand);
310            }
311
312            Expr::Call { name, args, .. } => {
313                let fn_name = &name.name;
314
315                // Handle builtins
316                match fn_name.as_str() {
317                    "print" => {
318                        self.emit.write("println!(\"{}\", ");
319                        self.generate_expr(&args[0]);
320                        self.emit.write(")");
321                    }
322                    "str" => {
323                        self.generate_expr(&args[0]);
324                        self.emit.write(".to_string()");
325                    }
326                    "len" => {
327                        self.generate_expr(&args[0]);
328                        self.emit.write(".len() as i64");
329                    }
330                    _ => {
331                        self.emit.write(fn_name);
332                        self.emit.write("(");
333                        for (i, arg) in args.iter().enumerate() {
334                            if i > 0 {
335                                self.emit.write(", ");
336                            }
337                            self.generate_expr(arg);
338                        }
339                        self.emit.write(")");
340                    }
341                }
342            }
343
344            Expr::SelfField { field, .. } => {
345                self.emit.write("self.");
346                self.emit.write(&field.name);
347            }
348
349            Expr::SelfMethodCall { method, args, .. } => {
350                self.emit.write("self.");
351                self.emit.write(&method.name);
352                self.emit.write("(");
353                for (i, arg) in args.iter().enumerate() {
354                    if i > 0 {
355                        self.emit.write(", ");
356                    }
357                    self.generate_expr(arg);
358                }
359                self.emit.write(")");
360            }
361
362            Expr::List { elements, .. } => {
363                self.emit.write("vec![");
364                for (i, elem) in elements.iter().enumerate() {
365                    if i > 0 {
366                        self.emit.write(", ");
367                    }
368                    self.generate_expr(elem);
369                }
370                self.emit.write("]");
371            }
372
373            Expr::Paren { inner, .. } => {
374                self.emit.write("(");
375                self.generate_expr(inner);
376                self.emit.write(")");
377            }
378
379            Expr::Infer { template, .. } => {
380                self.emit.write("ctx.infer_string(&");
381                self.emit_string_template(template);
382                self.emit.write(").await?");
383            }
384
385            Expr::Spawn { agent, fields, .. } => {
386                self.emit.write("sage_runtime::spawn(|ctx| ");
387                self.emit.write(&agent.name);
388                if fields.is_empty() {
389                    self.emit.write(".on_start(ctx))");
390                } else {
391                    self.emit.write(" { ");
392                    for (i, field) in fields.iter().enumerate() {
393                        if i > 0 {
394                            self.emit.write(", ");
395                        }
396                        self.emit.write(&field.name.name);
397                        self.emit.write(": ");
398                        self.generate_expr(&field.value);
399                    }
400                    self.emit.write(" }.on_start(ctx))");
401                }
402            }
403
404            Expr::Await { handle, .. } => {
405                self.generate_expr(handle);
406                self.emit.write(".result().await?");
407            }
408
409            Expr::Send { handle, message, .. } => {
410                self.generate_expr(handle);
411                self.emit.write(".send(sage_runtime::Message::new(");
412                self.generate_expr(message);
413                self.emit.write(")?).await?");
414            }
415
416            Expr::Emit { value, .. } => {
417                self.emit.write("ctx.emit(");
418                self.generate_expr(value);
419                self.emit.write(")");
420            }
421
422            Expr::StringInterp { template, .. } => {
423                self.emit_string_template(template);
424            }
425        }
426    }
427
428    fn emit_literal(&mut self, lit: &Literal) {
429        match lit {
430            Literal::Int(n) => {
431                self.emit.write(&format!("{n}_i64"));
432            }
433            Literal::Float(f) => {
434                self.emit.write(&format!("{f}_f64"));
435            }
436            Literal::Bool(b) => {
437                self.emit.write(if *b { "true" } else { "false" });
438            }
439            Literal::String(s) => {
440                // Escape the string for Rust
441                self.emit.write("\"");
442                for c in s.chars() {
443                    match c {
444                        '"' => self.emit.write_raw("\\\""),
445                        '\\' => self.emit.write_raw("\\\\"),
446                        '\n' => self.emit.write_raw("\\n"),
447                        '\r' => self.emit.write_raw("\\r"),
448                        '\t' => self.emit.write_raw("\\t"),
449                        _ => self.emit.write_raw(&c.to_string()),
450                    }
451                }
452                self.emit.write("\".to_string()");
453            }
454        }
455    }
456
457    fn emit_string_template(&mut self, template: &sage_parser::StringTemplate) {
458        if !template.has_interpolations() {
459            // Simple string literal
460            if let Some(StringPart::Literal(s)) = template.parts.first() {
461                self.emit.write("\"");
462                self.emit.write_raw(s);
463                self.emit.write("\".to_string()");
464            }
465            return;
466        }
467
468        // Build format string and args
469        self.emit.write("format!(\"");
470        for part in &template.parts {
471            match part {
472                StringPart::Literal(s) => {
473                    // Escape braces for format string
474                    let escaped = s.replace('{', "{{").replace('}', "}}");
475                    self.emit.write_raw(&escaped);
476                }
477                StringPart::Interpolation(_) => {
478                    self.emit.write_raw("{}");
479                }
480            }
481        }
482        self.emit.write("\"");
483
484        // Add the interpolation args
485        for part in &template.parts {
486            if let StringPart::Interpolation(ident) = part {
487                self.emit.write(", ");
488                self.emit.write(&ident.name);
489            }
490        }
491        self.emit.write(")");
492    }
493
494    fn emit_type(&mut self, ty: &TypeExpr) {
495        match ty {
496            TypeExpr::Int => self.emit.write("i64"),
497            TypeExpr::Float => self.emit.write("f64"),
498            TypeExpr::Bool => self.emit.write("bool"),
499            TypeExpr::String => self.emit.write("String"),
500            TypeExpr::Unit => self.emit.write("()"),
501            TypeExpr::List(inner) => {
502                self.emit.write("Vec<");
503                self.emit_type(inner);
504                self.emit.write(">");
505            }
506            TypeExpr::Option(inner) => {
507                self.emit.write("Option<");
508                self.emit_type(inner);
509                self.emit.write(">");
510            }
511            TypeExpr::Inferred(inner) => {
512                // Inferred<T> just becomes T at runtime
513                self.emit_type(inner);
514            }
515            TypeExpr::Agent(agent_name) => {
516                // Agent handles use the agent's output type, but we don't know it here
517                // For now, just use a generic output type
518                self.emit.write("AgentHandle<");
519                self.emit.write(&agent_name.name);
520                self.emit.write("Output>");
521            }
522            TypeExpr::Named(name) => {
523                self.emit.write(&name.name);
524            }
525        }
526    }
527
528    fn emit_binop(&mut self, op: &BinOp) {
529        let s = match op {
530            BinOp::Add => "+",
531            BinOp::Sub => "-",
532            BinOp::Mul => "*",
533            BinOp::Div => "/",
534            BinOp::Eq => "==",
535            BinOp::Ne => "!=",
536            BinOp::Lt => "<",
537            BinOp::Gt => ">",
538            BinOp::Le => "<=",
539            BinOp::Ge => ">=",
540            BinOp::And => "&&",
541            BinOp::Or => "||",
542            BinOp::Concat => "++", // Handled specially above
543        };
544        self.emit.write(s);
545    }
546
547    fn emit_unaryop(&mut self, op: &UnaryOp) {
548        let s = match op {
549            UnaryOp::Neg => "-",
550            UnaryOp::Not => "!",
551        };
552        self.emit.write(s);
553    }
554
555    fn infer_agent_output_type(&self, agent: &AgentDecl) -> String {
556        // Look for emit expression in start handler to infer return type
557        // For now, default to i64
558        for handler in &agent.handlers {
559            if let EventKind::Start = &handler.event {
560                if let Some(ty) = self.find_emit_type(&handler.body) {
561                    return ty;
562                }
563            }
564        }
565        "i64".to_string()
566    }
567
568    fn find_emit_type(&self, block: &Block) -> Option<String> {
569        for stmt in &block.stmts {
570            if let Stmt::Expr { expr, .. } = stmt {
571                if let Expr::Emit { value, .. } = expr {
572                    return Some(self.infer_expr_type(value));
573                }
574            }
575            // Check nested blocks
576            if let Stmt::If {
577                then_block,
578                else_block,
579                ..
580            } = stmt
581            {
582                if let Some(ty) = self.find_emit_type(then_block) {
583                    return Some(ty);
584                }
585                if let Some(else_branch) = else_block {
586                    if let sage_parser::ElseBranch::Block(block) = else_branch {
587                        if let Some(ty) = self.find_emit_type(block) {
588                            return Some(ty);
589                        }
590                    }
591                }
592            }
593        }
594        None
595    }
596
597    fn infer_expr_type(&self, expr: &Expr) -> String {
598        match expr {
599            Expr::Literal { value, .. } => match value {
600                Literal::Int(_) => "i64".to_string(),
601                Literal::Float(_) => "f64".to_string(),
602                Literal::Bool(_) => "bool".to_string(),
603                Literal::String(_) => "String".to_string(),
604            },
605            Expr::Var { .. } => "i64".to_string(), // Conservative default
606            Expr::Binary { op, .. } => {
607                if matches!(
608                    op,
609                    BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge
610                ) {
611                    "bool".to_string()
612                } else if matches!(op, BinOp::Concat) {
613                    "String".to_string()
614                } else {
615                    "i64".to_string()
616                }
617            }
618            Expr::Infer { .. } | Expr::StringInterp { .. } => "String".to_string(),
619            Expr::Call { name, .. } if name.name == "str" => "String".to_string(),
620            Expr::Call { name, .. } if name.name == "len" => "i64".to_string(),
621            _ => "i64".to_string(),
622        }
623    }
624}
625
#[cfg(test)]
mod tests {
    use super::*;
    use sage_lexer::lex;
    use sage_parser::parse;
    use std::sync::Arc;

    /// Runs the full front end (lex, parse) plus codegen on `source` and returns
    /// only the generated `main.rs`. Any lex or parse error fails the test.
    fn generate_source(source: &str) -> String {
        let lexed = lex(source).expect("lexing failed");
        let (parsed, errors) = parse(lexed.tokens(), Arc::<str>::from(source));
        assert!(errors.is_empty(), "parse errors: {errors:?}");
        let program = parsed.expect("should parse");
        generate(&program, "test").main_rs
    }

    #[test]
    fn generate_minimal_program() {
        let source = r#"
            agent Main {
                on start {
                    emit(42);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        for fragment in [
            "struct Main;",
            "async fn on_start",
            "ctx.emit(42_i64)",
            "#[tokio::main]",
        ] {
            assert!(output.contains(fragment));
        }
    }

    #[test]
    fn generate_function() {
        let source = r#"
            fn add(a: Int, b: Int) -> Int {
                return a + b;
            }
            agent Main {
                on start {
                    emit(add(1, 2));
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        for fragment in ["fn add(a: i64, b: i64) -> i64", "return (a + b);"] {
            assert!(output.contains(fragment));
        }
    }

    #[test]
    fn generate_agent_with_beliefs() {
        let source = r#"
            agent Worker {
                belief value: Int

                on start {
                    emit(self.value * 2);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        for fragment in ["struct Worker {", "value: i64,", "self.value"] {
            assert!(output.contains(fragment));
        }
    }

    #[test]
    fn generate_string_interpolation() {
        let source = r#"
            agent Main {
                on start {
                    let name = "World";
                    let msg = "Hello, {name}!";
                    print(msg);
                    emit(0);
                }
            }
            run Main;
        "#;

        assert!(generate_source(source).contains("format!(\"Hello, {}!\", name)"));
    }

    #[test]
    fn generate_control_flow() {
        let source = r#"
            agent Main {
                on start {
                    let x = 10;
                    if x > 5 {
                        emit(1);
                    } else {
                        emit(0);
                    }
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("if (x > 5_i64)"), "output:\n{output}");
        // `else` follows the closing brace of the then-branch on the same line.
        assert!(output.contains("else"), "output:\n{output}");
    }

    #[test]
    fn generate_loops() {
        let source = r#"
            agent Main {
                on start {
                    for x in [1, 2, 3] {
                        print(str(x));
                    }
                    let n = 0;
                    while n < 5 {
                        n = n + 1;
                    }
                    emit(n);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        for fragment in ["for x in vec![1_i64, 2_i64, 3_i64]", "while (n < 5_i64)"] {
            assert!(output.contains(fragment));
        }
    }
}
765}