1use crate::emit::Emitter;
4use sage_loader::ModuleTree;
5use sage_parser::{
6 AgentDecl, BinOp, Block, ConstDecl, EnumDecl, EventKind, Expr, FnDecl, Literal, Program,
7 RecordDecl, Stmt, StringPart, UnaryOp,
8};
9use sage_types::TypeExpr;
10
/// Files produced by code generation for one standalone Cargo project.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedProject {
    /// Generated Rust source for the project's `main.rs`.
    pub main_rs: String,
    /// Generated `Cargo.toml` manifest for the project.
    pub cargo_toml: String,
}
18
19pub fn generate(program: &Program, project_name: &str) -> GeneratedProject {
21 let mut gen = Generator::new();
22 let main_rs = gen.generate_program(program);
23 let cargo_toml = gen.generate_cargo_toml(project_name);
24 GeneratedProject {
25 main_rs,
26 cargo_toml,
27 }
28}
29
30pub fn generate_module_tree(tree: &ModuleTree, project_name: &str) -> GeneratedProject {
35 let mut gen = Generator::new();
36 let main_rs = gen.generate_module_tree(tree);
37 let cargo_toml = gen.generate_cargo_toml(project_name);
38 GeneratedProject {
39 main_rs,
40 cargo_toml,
41 }
42}
43
44struct Generator {
45 emit: Emitter,
46}
47
48impl Generator {
49 fn new() -> Self {
50 Self {
51 emit: Emitter::new(),
52 }
53 }
54
55 fn generate_program(&mut self, program: &Program) -> String {
56 self.emit
58 .writeln("//! Generated by Sage compiler. Do not edit.");
59 self.emit.blank_line();
60 self.emit.writeln("use sage_runtime::prelude::*;");
61 self.emit.blank_line();
62
63 for const_decl in &program.consts {
65 self.generate_const(const_decl);
66 self.emit.blank_line();
67 }
68
69 for enum_decl in &program.enums {
71 self.generate_enum(enum_decl);
72 self.emit.blank_line();
73 }
74
75 for record in &program.records {
77 self.generate_record(record);
78 self.emit.blank_line();
79 }
80
81 for func in &program.functions {
83 self.generate_function(func);
84 self.emit.blank_line();
85 }
86
87 for agent in &program.agents {
89 self.generate_agent(agent);
90 self.emit.blank_line();
91 }
92
93 if let Some(run_agent) = &program.run_agent {
95 let has_error_handler = program
97 .agents
98 .iter()
99 .find(|a| a.name.name == run_agent.name)
100 .map_or(false, |a| {
101 a.handlers
102 .iter()
103 .any(|h| matches!(h.event, EventKind::Error { .. }))
104 });
105 self.generate_main(&run_agent.name, has_error_handler);
106 }
107
108 std::mem::take(&mut self.emit).finish()
109 }
110
111 fn generate_module_tree(&mut self, tree: &ModuleTree) -> String {
112 self.emit
114 .writeln("//! Generated by Sage compiler. Do not edit.");
115 self.emit.blank_line();
116 self.emit.writeln("use sage_runtime::prelude::*;");
117 self.emit.blank_line();
118
119 for (path, module) in &tree.modules {
125 if path != &tree.root {
126 self.emit.write("// Module: ");
127 if path.is_empty() {
128 self.emit.writeln("(root)");
129 } else {
130 self.emit.writeln(&path.join("::"));
131 }
132
133 for const_decl in &module.program.consts {
134 self.generate_const(const_decl);
135 self.emit.blank_line();
136 }
137
138 for enum_decl in &module.program.enums {
139 self.generate_enum(enum_decl);
140 self.emit.blank_line();
141 }
142
143 for record in &module.program.records {
144 self.generate_record(record);
145 self.emit.blank_line();
146 }
147
148 for func in &module.program.functions {
149 self.generate_function(func);
150 self.emit.blank_line();
151 }
152
153 for agent in &module.program.agents {
154 self.generate_agent(agent);
155 self.emit.blank_line();
156 }
157 }
158 }
159
160 if let Some(root_module) = tree.modules.get(&tree.root) {
162 self.emit.writeln("// Root module");
163
164 for const_decl in &root_module.program.consts {
165 self.generate_const(const_decl);
166 self.emit.blank_line();
167 }
168
169 for enum_decl in &root_module.program.enums {
170 self.generate_enum(enum_decl);
171 self.emit.blank_line();
172 }
173
174 for record in &root_module.program.records {
175 self.generate_record(record);
176 self.emit.blank_line();
177 }
178
179 for func in &root_module.program.functions {
180 self.generate_function(func);
181 self.emit.blank_line();
182 }
183
184 for agent in &root_module.program.agents {
185 self.generate_agent(agent);
186 self.emit.blank_line();
187 }
188
189 if let Some(run_agent) = &root_module.program.run_agent {
191 let has_error_handler = root_module
193 .program
194 .agents
195 .iter()
196 .find(|a| a.name.name == run_agent.name)
197 .map_or(false, |a| {
198 a.handlers
199 .iter()
200 .any(|h| matches!(h.event, EventKind::Error { .. }))
201 });
202 self.generate_main(&run_agent.name, has_error_handler);
203 }
204 }
205
206 std::mem::take(&mut self.emit).finish()
207 }
208
209 fn generate_cargo_toml(&self, name: &str) -> String {
210 format!(
213 r#"[package]
214name = "{name}"
215version = "0.1.0"
216edition = "2021"
217
218[dependencies]
219sage-runtime = {{ path = "../../../crates/sage-runtime" }}
220tokio = {{ version = "1", features = ["full"] }}
221serde = {{ version = "1", features = ["derive"] }}
222serde_json = "1"
223
224# Standalone project, not part of parent workspace
225[workspace]
226"#
227 )
228 }
229
230 fn generate_const(&mut self, const_decl: &ConstDecl) {
231 if const_decl.is_pub {
232 self.emit.write("pub ");
233 }
234 self.emit.write("const ");
235 self.emit.write(&const_decl.name.name);
236 self.emit.write(": ");
237 self.emit_type(&const_decl.ty);
238 self.emit.write(" = ");
239 self.generate_expr(&const_decl.value);
240 self.emit.writeln(";");
241 }
242
243 fn generate_enum(&mut self, enum_decl: &EnumDecl) {
244 if enum_decl.is_pub {
245 self.emit.write("pub ");
246 }
247 self.emit
248 .writeln("#[derive(Debug, Clone, Copy, PartialEq, Eq)]");
249 self.emit.write("enum ");
250 self.emit.write(&enum_decl.name.name);
251 self.emit.writeln(" {");
252 self.emit.indent();
253 for variant in &enum_decl.variants {
254 self.emit.write(&variant.name);
255 self.emit.writeln(",");
256 }
257 self.emit.dedent();
258 self.emit.writeln("}");
259 }
260
261 fn generate_record(&mut self, record: &RecordDecl) {
262 if record.is_pub {
263 self.emit.write("pub ");
264 }
265 self.emit.writeln("#[derive(Debug, Clone)]");
266 self.emit.write("struct ");
267 self.emit.write(&record.name.name);
268 self.emit.writeln(" {");
269 self.emit.indent();
270 for field in &record.fields {
271 self.emit.write(&field.name.name);
272 self.emit.write(": ");
273 self.emit_type(&field.ty);
274 self.emit.writeln(",");
275 }
276 self.emit.dedent();
277 self.emit.writeln("}");
278 }
279
280 fn generate_function(&mut self, func: &FnDecl) {
281 if func.is_pub {
283 self.emit.write("pub ");
284 }
285 self.emit.write("fn ");
286 self.emit.write(&func.name.name);
287 self.emit.write("(");
288
289 for (i, param) in func.params.iter().enumerate() {
290 if i > 0 {
291 self.emit.write(", ");
292 }
293 self.emit.write(¶m.name.name);
294 self.emit.write(": ");
295 self.emit_type(¶m.ty);
296 }
297
298 self.emit.write(") -> ");
299
300 if func.is_fallible {
302 self.emit.write("SageResult<");
303 self.emit_type(&func.return_ty);
304 self.emit.write(">");
305 } else {
306 self.emit_type(&func.return_ty);
307 }
308
309 self.emit.write(" ");
310 self.generate_block(&func.body);
311 }
312
313 fn generate_agent(&mut self, agent: &AgentDecl) {
314 let name = &agent.name.name;
315
316 if agent.is_pub {
318 self.emit.write("pub ");
319 }
320 self.emit.write("struct ");
321 self.emit.write(name);
322 if agent.beliefs.is_empty() {
323 self.emit.writeln(";");
324 } else {
325 self.emit.writeln(" {");
326 self.emit.indent();
327 for belief in &agent.beliefs {
328 self.emit.write(&belief.name.name);
329 self.emit.write(": ");
330 self.emit_type(&belief.ty);
331 self.emit.writeln(",");
332 }
333 self.emit.dedent();
334 self.emit.writeln("}");
335 }
336 self.emit.blank_line();
337
338 let output_type = self.infer_agent_output_type(agent);
340
341 self.emit.write("impl ");
343 self.emit.write(name);
344 self.emit.writeln(" {");
345 self.emit.indent();
346
347 for handler in &agent.handlers {
349 match &handler.event {
350 EventKind::Start => {
351 self.emit
352 .write("async fn on_start(self, ctx: AgentContext<");
353 self.emit.write(&output_type);
354 self.emit.write(">) -> SageResult<");
355 self.emit.write(&output_type);
356 self.emit.writeln("> {");
357 self.emit.indent();
358 self.generate_block_contents(&handler.body);
359 self.emit.dedent();
360 self.emit.writeln("}");
361 }
362
363 EventKind::Error { param_name } => {
365 self.emit.write("async fn on_error(self, ");
366 self.emit.write(¶m_name.name);
367 self.emit.write(": SageError, ctx: AgentContext<");
368 self.emit.write(&output_type);
369 self.emit.write(">) -> SageResult<");
370 self.emit.write(&output_type);
371 self.emit.writeln("> {");
372 self.emit.indent();
373 self.generate_block_contents(&handler.body);
374 self.emit.dedent();
375 self.emit.writeln("}");
376 }
377
378 _ => {}
380 }
381 }
382
383 self.emit.dedent();
384 self.emit.writeln("}");
385 }
386
387 fn generate_main(&mut self, entry_agent: &str, has_error_handler: bool) {
388 self.emit.writeln("#[tokio::main]");
389 self.emit
390 .writeln("async fn main() -> Result<(), Box<dyn std::error::Error>> {");
391 self.emit.indent();
392
393 if has_error_handler {
394 self.emit
396 .writeln("let handle = sage_runtime::spawn(|ctx| async move {");
397 self.emit.indent();
398 self.emit.write("match ");
399 self.emit.write(entry_agent);
400 self.emit.writeln(".on_start(ctx.clone()).await {");
401 self.emit.indent();
402 self.emit.writeln("Ok(result) => Ok(result),");
403 self.emit.write("Err(e) => ");
404 self.emit.write(entry_agent);
405 self.emit.writeln(".on_error(e, ctx).await,");
406 self.emit.dedent();
407 self.emit.writeln("}");
408 self.emit.dedent();
409 self.emit.writeln("});");
410 } else {
411 self.emit.write("let handle = sage_runtime::spawn(|ctx| ");
412 self.emit.write(entry_agent);
413 self.emit.writeln(".on_start(ctx));");
414 }
415
416 self.emit.writeln("let result = handle.result().await?;");
417 self.emit.writeln("println!(\"{:?}\", result);");
418 self.emit.writeln("Ok(())");
419
420 self.emit.dedent();
421 self.emit.writeln("}");
422 }
423
424 fn generate_block(&mut self, block: &Block) {
425 self.emit.open_brace();
426 self.generate_block_contents(block);
427 self.emit.close_brace();
428 }
429
430 fn generate_block_inline(&mut self, block: &Block) {
431 self.emit.open_brace();
432 self.generate_block_contents(block);
433 self.emit.close_brace_inline();
434 }
435
436 fn generate_block_contents(&mut self, block: &Block) {
437 for stmt in &block.stmts {
438 self.generate_stmt(stmt);
439 }
440 }
441
442 fn generate_stmt(&mut self, stmt: &Stmt) {
443 match stmt {
444 Stmt::Let {
445 name, ty, value, ..
446 } => {
447 self.emit.write("let ");
448 if ty.is_some() {
449 self.emit.write(&name.name);
450 self.emit.write(": ");
451 self.emit_type(ty.as_ref().unwrap());
452 } else {
453 self.emit.write(&name.name);
454 }
455 self.emit.write(" = ");
456 self.generate_expr(value);
457 self.emit.writeln(";");
458 }
459
460 Stmt::Assign { name, value, .. } => {
461 self.emit.write(&name.name);
462 self.emit.write(" = ");
463 self.generate_expr(value);
464 self.emit.writeln(";");
465 }
466
467 Stmt::Return { value, .. } => {
468 self.emit.write("return ");
469 if let Some(expr) = value {
470 self.generate_expr(expr);
471 }
472 self.emit.writeln(";");
473 }
474
475 Stmt::If {
476 condition,
477 then_block,
478 else_block,
479 ..
480 } => {
481 self.emit.write("if ");
482 self.generate_expr(condition);
483 self.emit.write(" ");
484 if else_block.is_some() {
485 self.generate_block_inline(then_block);
486 self.emit.write(" else ");
487 match else_block.as_ref().unwrap() {
488 sage_parser::ElseBranch::Block(block) => {
489 self.generate_block(block);
490 }
491 sage_parser::ElseBranch::ElseIf(stmt) => {
492 self.generate_stmt(stmt);
493 }
494 }
495 } else {
496 self.generate_block(then_block);
497 }
498 }
499
500 Stmt::For {
501 var, iter, body, ..
502 } => {
503 self.emit.write("for ");
504 self.emit.write(&var.name);
505 self.emit.write(" in ");
506 self.generate_expr(iter);
507 self.emit.write(" ");
508 self.generate_block(body);
509 }
510
511 Stmt::While {
512 condition, body, ..
513 } => {
514 self.emit.write("while ");
515 self.generate_expr(condition);
516 self.emit.write(" ");
517 self.generate_block(body);
518 }
519
520 Stmt::Loop { body, .. } => {
521 self.emit.write("loop ");
522 self.generate_block(body);
523 }
524
525 Stmt::Break { .. } => {
526 self.emit.writeln("break;");
527 }
528
529 Stmt::Expr { expr, .. } => {
530 if let Expr::Emit { value, .. } = expr {
532 self.emit.write("return ctx.emit(");
533 self.generate_expr(value);
534 self.emit.writeln(");");
535 } else {
536 self.generate_expr(expr);
537 self.emit.writeln(";");
538 }
539 }
540 }
541 }
542
543 fn generate_expr(&mut self, expr: &Expr) {
544 match expr {
545 Expr::Literal { value, .. } => {
546 self.emit_literal(value);
547 }
548
549 Expr::Var { name, .. } => {
550 self.emit.write(&name.name);
551 }
552
553 Expr::Binary {
554 op, left, right, ..
555 } => {
556 if matches!(op, BinOp::Concat) {
558 self.emit.write("format!(\"{}{}\", ");
559 self.generate_expr(left);
560 self.emit.write(", ");
561 self.generate_expr(right);
562 self.emit.write(")");
563 } else {
564 self.emit.write("(");
565 self.generate_expr(left);
566 self.emit.write(" ");
567 self.emit_binop(op);
568 self.emit.write(" ");
569 self.generate_expr(right);
570 self.emit.write(")");
571 }
572 }
573
574 Expr::Unary { op, operand, .. } => {
575 self.emit_unaryop(op);
576 self.generate_expr(operand);
577 }
578
579 Expr::Call { name, args, .. } => {
580 let fn_name = &name.name;
581
582 match fn_name.as_str() {
584 "print" => {
585 self.emit.write("println!(\"{}\", ");
586 self.generate_expr(&args[0]);
587 self.emit.write(")");
588 }
589 "str" => {
590 self.generate_expr(&args[0]);
591 self.emit.write(".to_string()");
592 }
593 "len" => {
594 self.generate_expr(&args[0]);
595 self.emit.write(".len() as i64");
596 }
597 _ => {
598 self.emit.write(fn_name);
599 self.emit.write("(");
600 for (i, arg) in args.iter().enumerate() {
601 if i > 0 {
602 self.emit.write(", ");
603 }
604 self.generate_expr(arg);
605 }
606 self.emit.write(")");
607 }
608 }
609 }
610
611 Expr::SelfField { field, .. } => {
612 self.emit.write("self.");
613 self.emit.write(&field.name);
614 }
615
616 Expr::SelfMethodCall { method, args, .. } => {
617 self.emit.write("self.");
618 self.emit.write(&method.name);
619 self.emit.write("(");
620 for (i, arg) in args.iter().enumerate() {
621 if i > 0 {
622 self.emit.write(", ");
623 }
624 self.generate_expr(arg);
625 }
626 self.emit.write(")");
627 }
628
629 Expr::List { elements, .. } => {
630 self.emit.write("vec![");
631 for (i, elem) in elements.iter().enumerate() {
632 if i > 0 {
633 self.emit.write(", ");
634 }
635 self.generate_expr(elem);
636 }
637 self.emit.write("]");
638 }
639
640 Expr::Paren { inner, .. } => {
641 self.emit.write("(");
642 self.generate_expr(inner);
643 self.emit.write(")");
644 }
645
646 Expr::Infer { template, .. } => {
647 self.emit.write("ctx.infer_string(&");
648 self.emit_string_template(template);
649 self.emit.write(").await?");
650 }
651
652 Expr::Spawn { agent, fields, .. } => {
653 self.emit.write("sage_runtime::spawn(|ctx| ");
654 self.emit.write(&agent.name);
655 if fields.is_empty() {
656 self.emit.write(".on_start(ctx))");
657 } else {
658 self.emit.write(" { ");
659 for (i, field) in fields.iter().enumerate() {
660 if i > 0 {
661 self.emit.write(", ");
662 }
663 self.emit.write(&field.name.name);
664 self.emit.write(": ");
665 self.generate_expr(&field.value);
666 }
667 self.emit.write(" }.on_start(ctx))");
668 }
669 }
670
671 Expr::Await { handle, .. } => {
672 self.generate_expr(handle);
673 self.emit.write(".result().await?");
674 }
675
676 Expr::Send {
677 handle, message, ..
678 } => {
679 self.generate_expr(handle);
680 self.emit.write(".send(sage_runtime::Message::new(");
681 self.generate_expr(message);
682 self.emit.write(")?).await?");
683 }
684
685 Expr::Emit { value, .. } => {
686 self.emit.write("ctx.emit(");
687 self.generate_expr(value);
688 self.emit.write(")");
689 }
690
691 Expr::StringInterp { template, .. } => {
692 self.emit_string_template(template);
693 }
694
695 Expr::Match {
697 scrutinee, arms, ..
698 } => {
699 self.emit.write("match ");
700 self.generate_expr(scrutinee);
701 self.emit.writeln(" {");
702 self.emit.indent();
703 for arm in arms {
704 self.emit_pattern(&arm.pattern);
705 self.emit.write(" => ");
706 self.generate_expr(&arm.body);
707 self.emit.writeln(",");
708 }
709 self.emit.dedent();
710 self.emit.write("}");
711 }
712
713 Expr::RecordConstruct { name, fields, .. } => {
715 self.emit.write(&name.name);
716 self.emit.write(" { ");
717 for (i, field) in fields.iter().enumerate() {
718 if i > 0 {
719 self.emit.write(", ");
720 }
721 self.emit.write(&field.name.name);
722 self.emit.write(": ");
723 self.generate_expr(&field.value);
724 }
725 self.emit.write(" }");
726 }
727
728 Expr::FieldAccess { object, field, .. } => {
730 self.generate_expr(object);
731 self.emit.write(".");
732 self.emit.write(&field.name);
733 }
734
735 Expr::Receive { .. } => {
736 self.emit.write("ctx.receive().await?");
737 }
738
739 Expr::Try { expr, .. } => {
741 self.generate_expr(expr);
743 self.emit.write("?");
744 }
745
746 Expr::Catch {
747 expr,
748 error_bind,
749 recovery,
750 ..
751 } => {
752 self.emit.write("match ");
754 self.generate_expr(expr);
755 self.emit.writeln(" {");
756 self.emit.indent();
757
758 self.emit.writeln("Ok(__val) => __val,");
760
761 if let Some(err_name) = error_bind {
763 self.emit.write("Err(");
764 self.emit.write(&err_name.name);
765 self.emit.write(") => ");
766 } else {
767 self.emit.write("Err(_) => ");
768 }
769 self.generate_expr(recovery);
770 self.emit.writeln(",");
771
772 self.emit.dedent();
773 self.emit.write("}");
774 }
775 }
776 }
777
778 fn emit_pattern(&mut self, pattern: &sage_parser::Pattern) {
779 use sage_parser::Pattern;
780 match pattern {
781 Pattern::Wildcard { .. } => {
782 self.emit.write("_");
783 }
784 Pattern::Variant {
785 enum_name, variant, ..
786 } => {
787 if let Some(enum_name) = enum_name {
788 self.emit.write(&enum_name.name);
789 self.emit.write("::");
790 }
791 self.emit.write(&variant.name);
792 }
793 Pattern::Literal { value, .. } => {
794 self.emit_literal(value);
795 }
796 Pattern::Binding { name, .. } => {
797 self.emit.write(&name.name);
798 }
799 }
800 }
801
802 fn emit_literal(&mut self, lit: &Literal) {
803 match lit {
804 Literal::Int(n) => {
805 self.emit.write(&format!("{n}_i64"));
806 }
807 Literal::Float(f) => {
808 self.emit.write(&format!("{f}_f64"));
809 }
810 Literal::Bool(b) => {
811 self.emit.write(if *b { "true" } else { "false" });
812 }
813 Literal::String(s) => {
814 self.emit.write("\"");
816 for c in s.chars() {
817 match c {
818 '"' => self.emit.write_raw("\\\""),
819 '\\' => self.emit.write_raw("\\\\"),
820 '\n' => self.emit.write_raw("\\n"),
821 '\r' => self.emit.write_raw("\\r"),
822 '\t' => self.emit.write_raw("\\t"),
823 _ => self.emit.write_raw(&c.to_string()),
824 }
825 }
826 self.emit.write("\".to_string()");
827 }
828 }
829 }
830
831 fn emit_string_template(&mut self, template: &sage_parser::StringTemplate) {
832 if !template.has_interpolations() {
833 if let Some(StringPart::Literal(s)) = template.parts.first() {
835 self.emit.write("\"");
836 self.emit.write_raw(s);
837 self.emit.write("\".to_string()");
838 }
839 return;
840 }
841
842 self.emit.write("format!(\"");
844 for part in &template.parts {
845 match part {
846 StringPart::Literal(s) => {
847 let escaped = s.replace('{', "{{").replace('}', "}}");
849 self.emit.write_raw(&escaped);
850 }
851 StringPart::Interpolation(_) => {
852 self.emit.write_raw("{}");
853 }
854 }
855 }
856 self.emit.write("\"");
857
858 for part in &template.parts {
860 if let StringPart::Interpolation(ident) = part {
861 self.emit.write(", ");
862 self.emit.write(&ident.name);
863 }
864 }
865 self.emit.write(")");
866 }
867
868 fn emit_type(&mut self, ty: &TypeExpr) {
869 match ty {
870 TypeExpr::Int => self.emit.write("i64"),
871 TypeExpr::Float => self.emit.write("f64"),
872 TypeExpr::Bool => self.emit.write("bool"),
873 TypeExpr::String => self.emit.write("String"),
874 TypeExpr::Unit => self.emit.write("()"),
875 TypeExpr::List(inner) => {
876 self.emit.write("Vec<");
877 self.emit_type(inner);
878 self.emit.write(">");
879 }
880 TypeExpr::Option(inner) => {
881 self.emit.write("Option<");
882 self.emit_type(inner);
883 self.emit.write(">");
884 }
885 TypeExpr::Inferred(inner) => {
886 self.emit_type(inner);
888 }
889 TypeExpr::Agent(agent_name) => {
890 self.emit.write("AgentHandle<");
893 self.emit.write(&agent_name.name);
894 self.emit.write("Output>");
895 }
896 TypeExpr::Named(name) => {
897 self.emit.write(&name.name);
898 }
899
900 TypeExpr::Error => {
902 self.emit.write("sage_runtime::SageError");
903 }
904 }
905 }
906
907 fn emit_binop(&mut self, op: &BinOp) {
908 let s = match op {
909 BinOp::Add => "+",
910 BinOp::Sub => "-",
911 BinOp::Mul => "*",
912 BinOp::Div => "/",
913 BinOp::Eq => "==",
914 BinOp::Ne => "!=",
915 BinOp::Lt => "<",
916 BinOp::Gt => ">",
917 BinOp::Le => "<=",
918 BinOp::Ge => ">=",
919 BinOp::And => "&&",
920 BinOp::Or => "||",
921 BinOp::Concat => "++", };
923 self.emit.write(s);
924 }
925
926 fn emit_unaryop(&mut self, op: &UnaryOp) {
927 let s = match op {
928 UnaryOp::Neg => "-",
929 UnaryOp::Not => "!",
930 };
931 self.emit.write(s);
932 }
933
934 fn infer_agent_output_type(&self, agent: &AgentDecl) -> String {
935 for handler in &agent.handlers {
938 if let EventKind::Start = &handler.event {
939 if let Some(ty) = self.find_emit_type(&handler.body) {
940 return ty;
941 }
942 }
943 }
944 "i64".to_string()
945 }
946
947 fn find_emit_type(&self, block: &Block) -> Option<String> {
948 for stmt in &block.stmts {
949 if let Stmt::Expr { expr, .. } = stmt {
950 if let Expr::Emit { value, .. } = expr {
951 return Some(self.infer_expr_type(value));
952 }
953 }
954 if let Stmt::If {
956 then_block,
957 else_block,
958 ..
959 } = stmt
960 {
961 if let Some(ty) = self.find_emit_type(then_block) {
962 return Some(ty);
963 }
964 if let Some(else_branch) = else_block {
965 if let sage_parser::ElseBranch::Block(block) = else_branch {
966 if let Some(ty) = self.find_emit_type(block) {
967 return Some(ty);
968 }
969 }
970 }
971 }
972 }
973 None
974 }
975
976 fn infer_expr_type(&self, expr: &Expr) -> String {
977 match expr {
978 Expr::Literal { value, .. } => match value {
979 Literal::Int(_) => "i64".to_string(),
980 Literal::Float(_) => "f64".to_string(),
981 Literal::Bool(_) => "bool".to_string(),
982 Literal::String(_) => "String".to_string(),
983 },
984 Expr::Var { .. } => "i64".to_string(), Expr::Binary { op, .. } => {
986 if matches!(
987 op,
988 BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge
989 ) {
990 "bool".to_string()
991 } else if matches!(op, BinOp::Concat) {
992 "String".to_string()
993 } else {
994 "i64".to_string()
995 }
996 }
997 Expr::Infer { .. } | Expr::StringInterp { .. } => "String".to_string(),
998 Expr::Call { name, .. } if name.name == "str" => "String".to_string(),
999 Expr::Call { name, .. } if name.name == "len" => "i64".to_string(),
1000 _ => "i64".to_string(),
1001 }
1002 }
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use sage_lexer::lex;
    use sage_parser::parse;
    use std::sync::Arc;

    /// Lexes, parses and generates `main.rs` for `source`.
    ///
    /// Panics on any front-end error, so each test can focus on the generated
    /// text.
    fn compile_to_rust(source: &str) -> String {
        let lexed = lex(source).expect("lexing failed");
        let shared_source: Arc<str> = Arc::from(source);
        let (program, errors) = parse(lexed.tokens(), shared_source);
        assert!(errors.is_empty(), "parse errors: {errors:?}");
        let program = program.expect("should parse");
        generate(&program, "test").main_rs
    }

    /// Asserts that every snippet in `expected` occurs in `output`.
    ///
    /// The full generated source is printed on failure.
    fn assert_contains_all(output: &str, expected: &[&str]) {
        for snippet in expected {
            assert!(
                output.contains(*snippet),
                "expected generated code to contain {snippet:?}\noutput:\n{output}"
            );
        }
    }

    #[test]
    fn generate_minimal_program() {
        let output = compile_to_rust(
            r#"
            agent Main {
                on start {
                    emit(42);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "struct Main;",
                "async fn on_start",
                "ctx.emit(42_i64)",
                "#[tokio::main]",
            ],
        );
    }

    #[test]
    fn generate_function() {
        let output = compile_to_rust(
            r#"
            fn add(a: Int, b: Int) -> Int {
                return a + b;
            }
            agent Main {
                on start {
                    emit(add(1, 2));
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &["fn add(a: i64, b: i64) -> i64", "return (a + b);"],
        );
    }

    #[test]
    fn generate_agent_with_beliefs() {
        let output = compile_to_rust(
            r#"
            agent Worker {
                value: Int

                on start {
                    emit(self.value * 2);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["struct Worker {", "value: i64,", "self.value"]);
    }

    #[test]
    fn generate_string_interpolation() {
        let output = compile_to_rust(
            r#"
            agent Main {
                on start {
                    let name = "World";
                    let msg = "Hello, {name}!";
                    print(msg);
                    emit(0);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["format!(\"Hello, {}!\", name)"]);
    }

    #[test]
    fn generate_control_flow() {
        let output = compile_to_rust(
            r#"
            agent Main {
                on start {
                    let x = 10;
                    if x > 5 {
                        emit(1);
                    } else {
                        emit(0);
                    }
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["if (x > 5_i64)", "else"]);
    }

    #[test]
    fn generate_loops() {
        let output = compile_to_rust(
            r#"
            agent Main {
                on start {
                    for x in [1, 2, 3] {
                        print(str(x));
                    }
                    let n = 0;
                    while n < 5 {
                        n = n + 1;
                    }
                    emit(n);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &["for x in vec![1_i64, 2_i64, 3_i64]", "while (n < 5_i64)"],
        );
    }

    #[test]
    fn generate_pub_function() {
        let output = compile_to_rust(
            r#"
            pub fn helper(x: Int) -> Int {
                return x * 2;
            }
            agent Main {
                on start {
                    emit(helper(21));
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["pub fn helper(x: i64) -> i64"]);
    }

    #[test]
    fn generate_pub_agent() {
        let output = compile_to_rust(
            r#"
            pub agent Worker {
                on start {
                    emit(42);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["pub struct Worker;"]);
    }

    #[test]
    fn generate_module_tree_simple() {
        use sage_loader::load_single_file;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test.sg");
        fs::write(
            &file,
            r#"
agent Main {
    on start {
        emit(42);
    }
}
run Main;
"#,
        )
        .unwrap();

        let tree = load_single_file(&file).unwrap();
        let project = generate_module_tree(&tree, "test");

        assert_contains_all(
            &project.main_rs,
            &["struct Main;", "async fn on_start", "#[tokio::main]"],
        );
    }

    #[test]
    fn generate_record_declaration() {
        let output = compile_to_rust(
            r#"
            record Point {
                x: Int,
                y: Int,
            }
            agent Main {
                on start {
                    let p = Point { x: 10, y: 20 };
                    emit(p.x);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "#[derive(Debug, Clone)]",
                "struct Point {",
                "x: i64,",
                "y: i64,",
                "Point { x: 10_i64, y: 20_i64 }",
                "p.x",
            ],
        );
    }

    #[test]
    fn generate_enum_declaration() {
        let output = compile_to_rust(
            r#"
            enum Status {
                Active,
                Inactive,
                Pending,
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "#[derive(Debug, Clone, Copy, PartialEq, Eq)]",
                "enum Status {",
                "Active,",
                "Inactive,",
                "Pending,",
            ],
        );
    }

    #[test]
    fn generate_const_declaration() {
        let output = compile_to_rust(
            r#"
            const MAX_SIZE: Int = 100;
            const GREETING: String = "Hello";
            agent Main {
                on start {
                    emit(MAX_SIZE);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "const MAX_SIZE: i64 = 100_i64;",
                "const GREETING: String = \"Hello\".to_string();",
            ],
        );
    }

    #[test]
    fn generate_match_expression() {
        let output = compile_to_rust(
            r#"
            enum Status {
                Active,
                Inactive,
            }
            fn check_status(s: Status) -> Int {
                return match s {
                    Active => 1,
                    Inactive => 0,
                };
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &["match s {", "Active => 1_i64,", "Inactive => 0_i64,"],
        );
    }

    #[test]
    fn generate_fallible_function() {
        let output = compile_to_rust(
            r#"
            fn get_data(url: String) -> String fails {
                return url;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["fn get_data(url: String) -> SageResult<String>"]);
    }

    #[test]
    fn generate_try_expression() {
        let output = compile_to_rust(
            r#"
            fn fallible() -> Int fails { return 42; }
            fn caller() -> Int fails {
                let x = try fallible();
                return x;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["fallible()?"]);
    }

    #[test]
    fn generate_catch_expression() {
        let output = compile_to_rust(
            r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch { 0 };
                    emit(x);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &["match fallible()", "Ok(__val) => __val", "Err(_) => 0_i64"],
        );
    }

    #[test]
    fn generate_catch_with_binding() {
        let output = compile_to_rust(
            r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch(e) { 0 };
                    emit(x);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(&output, &["Err(e) => 0_i64"]);
    }

    #[test]
    fn generate_on_error_handler() {
        let output = compile_to_rust(
            r#"
            agent Main {
                on start {
                    emit(0);
                }
                on error(e) {
                    emit(1);
                }
            }
            run Main;
            "#,
        );
        assert_contains_all(
            &output,
            &["async fn on_error(self, e: SageError", ".on_error(e, ctx)"],
        );
    }
}