sage_parser/ast.rs
1//! Abstract Syntax Tree definitions for the Sage language.
2//!
3//! This module defines all AST node types that the parser produces.
4//! Every node carries a `Span` for error reporting.
5
6use sage_types::{Ident, Span, TypeExpr};
7use std::fmt;
8
9// =============================================================================
10// Program (top-level)
11// =============================================================================
12
/// A complete, parsed Sage program: the top-level node of the AST.
///
/// Holds every agent and function declaration together with the name of the
/// agent chosen as the entry point.
#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    /// Agent declarations (presumably in source order — confirm in the parser).
    pub agents: Vec<AgentDecl>,
    /// Function declarations (presumably in source order — confirm in the parser).
    pub functions: Vec<FnDecl>,
    /// The entry-point agent named by the `run AgentName` clause.
    ///
    /// NOTE(review): nothing in this type guarantees the name refers to one of
    /// `agents`; that check presumably happens in a later pass — confirm.
    pub run_agent: Ident,
    /// Span covering the entire program.
    pub span: Span,
}
25
26// =============================================================================
27// Agent declarations
28// =============================================================================
29
/// An agent declaration: `agent Name { ... }`.
///
/// The body lists the agent's state ([`BeliefDecl`]s) and the event handlers
/// ([`HandlerDecl`]s) that respond to lifecycle events and messages.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentDecl {
    /// The agent's name, as referenced by `spawn` and by `run`.
    pub name: Ident,
    /// Belief declarations (the agent's state).
    pub beliefs: Vec<BeliefDecl>,
    /// Event handlers (`on start`, `on message(...)`, `on stop`).
    pub handlers: Vec<HandlerDecl>,
    /// Span covering the entire declaration.
    pub span: Span,
}
42
/// A belief declaration inside an agent: `belief name: Type`.
///
/// Beliefs are the agent's state. Initial values are supplied at spawn time
/// through the [`FieldInit`]s of an [`Expr::Spawn`], and handlers read them via
/// [`Expr::SelfField`].
#[derive(Debug, Clone, PartialEq)]
pub struct BeliefDecl {
    /// The belief's name.
    pub name: Ident,
    /// The belief's declared type.
    pub ty: TypeExpr,
    /// Span covering the declaration.
    pub span: Span,
}
53
/// An event handler inside an agent: `on start { ... }`,
/// `on message(x: T) { ... }`, or `on stop { ... }`.
#[derive(Debug, Clone, PartialEq)]
pub struct HandlerDecl {
    /// Which event this handler responds to (and, for messages, the bound
    /// parameter).
    pub event: EventKind,
    /// The statements run when the event fires.
    pub body: Block,
    /// Span covering the entire handler, including the `on` keyword.
    pub span: Span,
}
64
/// The kind of event an [`HandlerDecl`] responds to.
///
/// This type carries no span of its own; use the enclosing handler's span.
#[derive(Debug, Clone, PartialEq)]
pub enum EventKind {
    /// `on start` — runs when the agent is spawned.
    Start,
    /// `on message(param: Type)` — runs when a message is received.
    Message {
        /// The name the incoming message is bound to inside the handler body.
        param_name: Ident,
        /// The declared type of the message.
        param_ty: TypeExpr,
    },
    /// `on stop` — runs during graceful shutdown.
    Stop,
}
80
81impl fmt::Display for EventKind {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 EventKind::Start => write!(f, "start"),
85 EventKind::Message {
86 param_name,
87 param_ty,
88 } => {
89 write!(f, "message({param_name}: {param_ty})")
90 }
91 EventKind::Stop => write!(f, "stop"),
92 }
93 }
94}
95
96// =============================================================================
97// Function declarations
98// =============================================================================
99
/// A top-level function declaration: `fn name(params) -> ReturnType { ... }`.
#[derive(Debug, Clone, PartialEq)]
pub struct FnDecl {
    /// The function's name, as used by [`Expr::Call`].
    pub name: Ident,
    /// The function's parameters, in declaration order.
    pub params: Vec<Param>,
    /// The declared return type.
    ///
    /// NOTE(review): this is not an `Option`, so if the surface syntax allows
    /// omitting `-> Type` the parser must fill in a default — confirm.
    pub return_ty: TypeExpr,
    /// The function body.
    pub body: Block,
    /// Span covering the entire declaration.
    pub span: Span,
}
114
/// A single function parameter: `name: Type`. The type annotation is required.
#[derive(Debug, Clone, PartialEq)]
pub struct Param {
    /// The parameter name.
    pub name: Ident,
    /// The parameter's declared type.
    pub ty: TypeExpr,
    /// Span covering the parameter.
    pub span: Span,
}
125
126// =============================================================================
127// Blocks and statements
128// =============================================================================
129
/// A brace-delimited sequence of statements: `{ stmt* }`.
///
/// Used for function bodies, handler bodies, loop bodies, and `if`/`else` arms.
#[derive(Debug, Clone, PartialEq)]
pub struct Block {
    /// The statements in this block, in order.
    pub stmts: Vec<Stmt>,
    /// Span covering the entire block (including braces).
    pub span: Span,
}
138
/// A statement inside a [`Block`].
///
/// Every variant carries a `span`; [`Stmt::span`] reads it uniformly.
#[derive(Debug, Clone, PartialEq)]
pub enum Stmt {
    /// Variable binding: `let name: Type = expr` or `let name = expr`.
    Let {
        /// The name being introduced.
        name: Ident,
        /// The type annotation, or `None` when it was omitted.
        ty: Option<TypeExpr>,
        /// The initial value (an initialiser is always present).
        value: Expr,
        /// Span covering the statement.
        span: Span,
    },

    /// Assignment: `name = expr`.
    ///
    /// NOTE(review): the target is a bare [`Ident`]; whether `self.field = expr`
    /// is represented by this variant is decided by the parser — confirm.
    Assign {
        /// The variable being assigned to.
        name: Ident,
        /// The new value.
        value: Expr,
        /// Span covering the statement.
        span: Span,
    },

    /// Return statement: `return` or `return expr`.
    Return {
        /// The returned value, or `None` for a bare `return`.
        value: Option<Expr>,
        /// Span covering the statement.
        span: Span,
    },

    /// Conditional: `if cond { ... }` with an optional `else` branch.
    If {
        /// The condition (must be `Bool`).
        condition: Expr,
        /// The block executed when the condition is true.
        then_block: Block,
        /// The `else` branch, if present; `else if` chains nest through
        /// [`ElseBranch::ElseIf`].
        else_block: Option<ElseBranch>,
        /// Span covering the statement.
        span: Span,
    },

    /// For loop: `for x in iter { ... }`.
    For {
        /// The loop variable.
        var: Ident,
        /// The iterable expression (must be `List<T>`).
        iter: Expr,
        /// The loop body.
        body: Block,
        /// Span covering the statement.
        span: Span,
    },

    /// While loop: `while cond { ... }`.
    While {
        /// The loop condition (must be `Bool`).
        condition: Expr,
        /// The loop body.
        body: Block,
        /// Span covering the statement.
        span: Span,
    },

    /// Expression statement: an expression used where a statement is
    /// expected, e.g. `send(h, msg)`.
    Expr {
        /// The expression.
        expr: Expr,
        /// Span covering the statement.
        span: Span,
    },
}
214
215impl Stmt {
216 /// Get the span of this statement.
217 #[must_use]
218 pub fn span(&self) -> &Span {
219 match self {
220 Stmt::Let { span, .. }
221 | Stmt::Assign { span, .. }
222 | Stmt::Return { span, .. }
223 | Stmt::If { span, .. }
224 | Stmt::For { span, .. }
225 | Stmt::While { span, .. }
226 | Stmt::Expr { span, .. } => span,
227 }
228 }
229}
230
/// The `else` part of a [`Stmt::If`].
///
/// Carries no span of its own; the enclosing `If` statement's span covers it.
#[derive(Debug, Clone, PartialEq)]
pub enum ElseBranch {
    /// A plain `else { ... }` block.
    Block(Block),
    /// An `else if ...` continuation.
    ///
    /// The boxed statement is expected to be a [`Stmt::If`], but the type does
    /// not enforce this — consumers should not assume it blindly.
    ElseIf(Box<Stmt>),
}
239
240// =============================================================================
241// Expressions
242// =============================================================================
243
/// An expression.
///
/// Every variant carries a `span`; [`Expr::span`] reads it uniformly. Nested
/// sub-expressions are stored in `Box` (or `Vec`) because the type is recursive.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// LLM inference: `infer("template")` or `infer("template" -> Type)`.
    Infer {
        /// The prompt template (may contain `{ident}` interpolations).
        template: StringTemplate,
        /// The type written after `->`, or `None` when it was omitted.
        result_ty: Option<TypeExpr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Agent spawning: `spawn AgentName { field: value, ... }`.
    Spawn {
        /// The name of the agent declaration to instantiate.
        agent: Ident,
        /// Initial belief values, one [`FieldInit`] per `field: value` pair.
        fields: Vec<FieldInit>,
        /// Span covering the expression.
        span: Span,
    },

    /// Await: `await expr`.
    Await {
        /// The expression producing the agent handle to await.
        handle: Box<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Send message: `send(handle, message)`.
    Send {
        /// The expression producing the agent handle to send to.
        handle: Box<Expr>,
        /// The message to send.
        message: Box<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Emit value: `emit(value)`.
    Emit {
        /// The value to emit to the awaiter.
        value: Box<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Call of a named function: `name(args)`.
    Call {
        /// The function name.
        name: Ident,
        /// The arguments, in call order.
        args: Vec<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Method call on the current agent: `self.method(args)`.
    SelfMethodCall {
        /// The method name.
        method: Ident,
        /// The arguments, in call order.
        args: Vec<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Read of one of the current agent's beliefs: `self.field`.
    SelfField {
        /// The field (belief) name.
        field: Ident,
        /// Span covering the expression.
        span: Span,
    },

    /// Binary operation: `left op right`.
    ///
    /// See [`BinOp::precedence`] for the relative binding strength of operators.
    Binary {
        /// The operator.
        op: BinOp,
        /// The left operand.
        left: Box<Expr>,
        /// The right operand.
        right: Box<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Prefix unary operation: `op operand`.
    Unary {
        /// The operator.
        op: UnaryOp,
        /// The operand.
        operand: Box<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// List literal: `[a, b, c]`.
    List {
        /// The list elements, in order.
        elements: Vec<Expr>,
        /// Span covering the expression.
        span: Span,
    },

    /// Literal value (integer, float, boolean, or plain string).
    Literal {
        /// The literal value.
        value: Literal,
        /// Span covering the expression.
        span: Span,
    },

    /// Variable reference.
    Var {
        /// The variable name.
        name: Ident,
        /// Span covering the expression.
        span: Span,
    },

    /// Parenthesized expression: `(expr)`, kept as its own node in the tree.
    Paren {
        /// The inner expression.
        inner: Box<Expr>,
        /// Span covering the expression (including parens).
        span: Span,
    },

    /// Interpolated string: `"Hello, {name}!"`.
    StringInterp {
        /// The string template with interpolations.
        template: StringTemplate,
        /// Span covering the expression.
        span: Span,
    },
}
383
384impl Expr {
385 /// Get the span of this expression.
386 #[must_use]
387 pub fn span(&self) -> &Span {
388 match self {
389 Expr::Infer { span, .. }
390 | Expr::Spawn { span, .. }
391 | Expr::Await { span, .. }
392 | Expr::Send { span, .. }
393 | Expr::Emit { span, .. }
394 | Expr::Call { span, .. }
395 | Expr::SelfMethodCall { span, .. }
396 | Expr::SelfField { span, .. }
397 | Expr::Binary { span, .. }
398 | Expr::Unary { span, .. }
399 | Expr::List { span, .. }
400 | Expr::Literal { span, .. }
401 | Expr::Var { span, .. }
402 | Expr::Paren { span, .. }
403 | Expr::StringInterp { span, .. } => span,
404 }
405 }
406}
407
/// One `field: value` pair in a spawn expression (see [`Expr::Spawn`]).
#[derive(Debug, Clone, PartialEq)]
pub struct FieldInit {
    /// The field (belief) name being initialised.
    pub name: Ident,
    /// The initial value.
    pub value: Expr,
    /// Span covering the field initialization.
    pub span: Span,
}
418
419// =============================================================================
420// Operators
421// =============================================================================
422
/// Binary operators.
///
/// Binding strength comes from [`BinOp::precedence`]; the `Display` impl
/// renders each operator's source symbol.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    // Arithmetic
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,

    // Comparison (equality binds more loosely than ordering)
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,

    // Logical
    /// `&&`
    And,
    /// `||`
    Or,

    // String
    /// `++` (string concatenation)
    Concat,
}
460
461impl BinOp {
462 /// Get the precedence of this operator (higher = binds tighter).
463 #[must_use]
464 pub fn precedence(self) -> u8 {
465 match self {
466 BinOp::Or => 1,
467 BinOp::And => 2,
468 BinOp::Eq | BinOp::Ne => 3,
469 BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge => 4,
470 BinOp::Concat => 5,
471 BinOp::Add | BinOp::Sub => 6,
472 BinOp::Mul | BinOp::Div => 7,
473 }
474 }
475
476 /// Check if this operator is left-associative.
477 #[must_use]
478 pub fn is_left_assoc(self) -> bool {
479 // All our operators are left-associative
480 true
481 }
482}
483
484impl fmt::Display for BinOp {
485 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
486 match self {
487 BinOp::Add => write!(f, "+"),
488 BinOp::Sub => write!(f, "-"),
489 BinOp::Mul => write!(f, "*"),
490 BinOp::Div => write!(f, "/"),
491 BinOp::Eq => write!(f, "=="),
492 BinOp::Ne => write!(f, "!="),
493 BinOp::Lt => write!(f, "<"),
494 BinOp::Gt => write!(f, ">"),
495 BinOp::Le => write!(f, "<="),
496 BinOp::Ge => write!(f, ">="),
497 BinOp::And => write!(f, "&&"),
498 BinOp::Or => write!(f, "||"),
499 BinOp::Concat => write!(f, "++"),
500 }
501 }
502}
503
/// Prefix unary operators (see [`Expr::Unary`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// `-` (numeric negation)
    Neg,
    /// `!` (logical not)
    Not,
}
512
513impl fmt::Display for UnaryOp {
514 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
515 match self {
516 UnaryOp::Neg => write!(f, "-"),
517 UnaryOp::Not => write!(f, "!"),
518 }
519 }
520}
521
522// =============================================================================
523// Literals
524// =============================================================================
525
/// A literal value (wrapped with its span by [`Expr::Literal`]).
///
/// Because of the `f64` payload, this type implements `PartialEq` but not `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// Integer literal, e.g. `42`.
    ///
    /// NOTE(review): whether a leading `-` (as in `-7`) is folded into the
    /// literal or parsed as [`UnaryOp::Neg`] is up to the parser — confirm.
    Int(i64),
    /// Float literal, e.g. `3.14` or `0.5` (same sign caveat as `Int`).
    Float(f64),
    /// Boolean literal: `true`, `false`
    Bool(bool),
    /// Plain string literal without interpolations: `"hello"`
    String(String),
}
538
539impl fmt::Display for Literal {
540 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
541 match self {
542 Literal::Int(n) => write!(f, "{n}"),
543 Literal::Float(n) => write!(f, "{n}"),
544 Literal::Bool(b) => write!(f, "{b}"),
545 Literal::String(s) => write!(f, "\"{s}\""),
546 }
547 }
548}
549
550// =============================================================================
551// String templates (for interpolation)
552// =============================================================================
553
/// A string template that may contain `{ident}` interpolations.
///
/// Only bare identifiers can be interpolated (see [`StringPart::Interpolation`]);
/// arbitrary expressions inside braces are not representable.
///
/// For example, `"Hello, {name}!"` becomes:
/// ```text
/// StringTemplate {
///     parts: [
///         StringPart::Literal("Hello, "),
///         StringPart::Interpolation(Ident("name")),
///         StringPart::Literal("!"),
///     ]
/// }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct StringTemplate {
    /// The literal and interpolated segments, in source order.
    pub parts: Vec<StringPart>,
    /// Span covering the entire template string.
    pub span: Span,
}
573
574impl StringTemplate {
575 /// Create a simple template with no interpolations.
576 #[must_use]
577 pub fn literal(s: String, span: Span) -> Self {
578 Self {
579 parts: vec![StringPart::Literal(s)],
580 span,
581 }
582 }
583
584 /// Check if this template has any interpolations.
585 #[must_use]
586 pub fn has_interpolations(&self) -> bool {
587 self.parts
588 .iter()
589 .any(|p| matches!(p, StringPart::Interpolation(_)))
590 }
591
592 /// Get all interpolated identifiers.
593 pub fn interpolations(&self) -> impl Iterator<Item = &Ident> {
594 self.parts.iter().filter_map(|p| match p {
595 StringPart::Interpolation(ident) => Some(ident),
596 StringPart::Literal(_) => None,
597 })
598 }
599}
600
601impl fmt::Display for StringTemplate {
602 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
603 write!(f, "\"")?;
604 for part in &self.parts {
605 match part {
606 StringPart::Literal(s) => write!(f, "{s}")?,
607 StringPart::Interpolation(ident) => write!(f, "{{{ident}}}")?,
608 }
609 }
610 write!(f, "\"")
611 }
612}
613
/// One segment of a [`StringTemplate`].
#[derive(Debug, Clone, PartialEq)]
pub enum StringPart {
    /// A run of literal text copied as-is.
    Literal(String),
    /// An interpolated identifier: `{ident}`. Only a bare identifier is
    /// representable here, not an arbitrary expression.
    Interpolation(Ident),
}
622
623// =============================================================================
624// Tests
625// =============================================================================
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn binop_precedence() {
        // Tightest to loosest:
        // Mul/Div > Add/Sub > Concat > ordering (<, >, <=, >=) > equality (==, !=) > And > Or
        assert!(BinOp::Mul.precedence() > BinOp::Add.precedence());
        assert!(BinOp::Add.precedence() > BinOp::Concat.precedence());
        assert!(BinOp::Concat.precedence() > BinOp::Lt.precedence());
        assert!(BinOp::Lt.precedence() > BinOp::Eq.precedence());
        assert!(BinOp::Eq.precedence() > BinOp::And.precedence());
        assert!(BinOp::And.precedence() > BinOp::Or.precedence());
    }

    #[test]
    fn binop_same_tier_and_associativity() {
        assert_eq!(BinOp::Mul.precedence(), BinOp::Div.precedence());
        assert_eq!(BinOp::Add.precedence(), BinOp::Sub.precedence());
        assert_eq!(BinOp::Eq.precedence(), BinOp::Ne.precedence());
        assert_eq!(BinOp::Lt.precedence(), BinOp::Ge.precedence());
        assert!(BinOp::Concat.is_left_assoc());
        assert!(BinOp::Or.is_left_assoc());
    }

    #[test]
    fn binop_display() {
        assert_eq!(BinOp::Add.to_string(), "+");
        assert_eq!(BinOp::Eq.to_string(), "==");
        assert_eq!(BinOp::Ne.to_string(), "!=");
        assert_eq!(BinOp::Concat.to_string(), "++");
        assert_eq!(BinOp::And.to_string(), "&&");
        assert_eq!(BinOp::Or.to_string(), "||");
    }

    #[test]
    fn unaryop_display() {
        assert_eq!(UnaryOp::Neg.to_string(), "-");
        assert_eq!(UnaryOp::Not.to_string(), "!");
    }

    #[test]
    fn literal_display() {
        assert_eq!(Literal::Int(42).to_string(), "42");
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` lint
        // rejects float literals that approximate `std::f64::consts::PI`.
        assert_eq!(Literal::Float(2.5).to_string(), "2.5");
        assert_eq!(Literal::Bool(true).to_string(), "true");
        assert_eq!(Literal::String("hello".into()).to_string(), "\"hello\"");
    }

    #[test]
    fn event_kind_display() {
        assert_eq!(EventKind::Start.to_string(), "start");
        assert_eq!(EventKind::Stop.to_string(), "stop");

        let msg = EventKind::Message {
            param_name: Ident::dummy("msg"),
            param_ty: TypeExpr::String,
        };
        assert_eq!(format!("{msg}"), "message(msg: String)");
    }

    #[test]
    fn string_template_literal() {
        let template = StringTemplate::literal("hello".into(), Span::dummy());
        assert!(!template.has_interpolations());
        assert_eq!(template.interpolations().count(), 0);
        assert_eq!(format!("{template}"), "\"hello\"");
    }

    #[test]
    fn string_template_with_interpolation() {
        let template = StringTemplate {
            parts: vec![
                StringPart::Literal("Hello, ".into()),
                StringPart::Interpolation(Ident::dummy("name")),
                StringPart::Literal("!".into()),
            ],
            span: Span::dummy(),
        };
        assert!(template.has_interpolations());
        assert_eq!(format!("{template}"), "\"Hello, {name}!\"");

        let interps: Vec<_> = template.interpolations().collect();
        assert_eq!(interps.len(), 1);
        assert_eq!(interps[0].name, "name");
    }

    #[test]
    fn string_template_interpolations_in_order() {
        let template = StringTemplate {
            parts: vec![
                StringPart::Interpolation(Ident::dummy("a")),
                StringPart::Literal(" and ".into()),
                StringPart::Interpolation(Ident::dummy("b")),
            ],
            span: Span::dummy(),
        };
        let interps: Vec<_> = template.interpolations().collect();
        assert_eq!(interps.len(), 2);
        assert_eq!(interps[0].name, "a");
        assert_eq!(interps[1].name, "b");
        assert_eq!(format!("{template}"), "\"{a} and {b}\"");
    }

    #[test]
    fn expr_span() {
        let span = Span::dummy();
        let expr = Expr::Literal {
            value: Literal::Int(42),
            span: span.clone(),
        };
        assert_eq!(expr.span(), &span);
    }

    #[test]
    fn stmt_span() {
        let span = Span::dummy();
        let stmt = Stmt::Return {
            value: None,
            span: span.clone(),
        };
        assert_eq!(stmt.span(), &span);
    }
}