Skip to main content

sage_parser/
ast.rs

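//! Abstract Syntax Tree definitions for the Sage language.
//!
//! This module defines all AST node types that the parser produces.
//! Every node that maps to a source construct carries a `Span` for error
//! reporting; small payload types (operators, literal values, event kinds,
//! string-template parts) are wrapped by a spanned parent node instead.
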
use sage_types::{Ident, Span, TypeExpr};
use std::fmt;

9// =============================================================================
10// Program (top-level)
11// =============================================================================
12
13/// A complete Sage program.
14#[derive(Debug, Clone, PartialEq)]
15pub struct Program {
16    /// Agent declarations.
17    pub agents: Vec<AgentDecl>,
18    /// Function declarations.
19    pub functions: Vec<FnDecl>,
20    /// The entry-point agent (from `run AgentName`).
21    pub run_agent: Ident,
22    /// Span covering the entire program.
23    pub span: Span,
24}
25
26// =============================================================================
27// Agent declarations
28// =============================================================================
29
30/// An agent declaration: `agent Name { ... }`
31#[derive(Debug, Clone, PartialEq)]
32pub struct AgentDecl {
33    /// The agent's name.
34    pub name: Ident,
35    /// Belief declarations (agent state).
36    pub beliefs: Vec<BeliefDecl>,
37    /// Event handlers.
38    pub handlers: Vec<HandlerDecl>,
39    /// Span covering the entire declaration.
40    pub span: Span,
41}
42
43/// A belief declaration: `belief name: Type`
44#[derive(Debug, Clone, PartialEq)]
45pub struct BeliefDecl {
46    /// The belief's name.
47    pub name: Ident,
48    /// The belief's type.
49    pub ty: TypeExpr,
50    /// Span covering the declaration.
51    pub span: Span,
52}
53
54/// An event handler: `on start { ... }`, `on message(x: T) { ... }`, `on stop { ... }`
55#[derive(Debug, Clone, PartialEq)]
56pub struct HandlerDecl {
57    /// The event kind this handler responds to.
58    pub event: EventKind,
59    /// The handler body.
60    pub body: Block,
61    /// Span covering the entire handler.
62    pub span: Span,
63}
64
65/// The kind of event a handler responds to.
66#[derive(Debug, Clone, PartialEq)]
67pub enum EventKind {
68    /// `on start` — runs when the agent is spawned.
69    Start,
70    /// `on message(param: Type)` — runs when a message is received.
71    Message {
72        /// The parameter name for the incoming message.
73        param_name: Ident,
74        /// The type of the message.
75        param_ty: TypeExpr,
76    },
77    /// `on stop` — runs during graceful shutdown.
78    Stop,
79}
80
81impl fmt::Display for EventKind {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            EventKind::Start => write!(f, "start"),
85            EventKind::Message {
86                param_name,
87                param_ty,
88            } => {
89                write!(f, "message({param_name}: {param_ty})")
90            }
91            EventKind::Stop => write!(f, "stop"),
92        }
93    }
94}
95
96// =============================================================================
97// Function declarations
98// =============================================================================
99
100/// A function declaration: `fn name(params) -> ReturnType { ... }`
101#[derive(Debug, Clone, PartialEq)]
102pub struct FnDecl {
103    /// The function's name.
104    pub name: Ident,
105    /// The function's parameters.
106    pub params: Vec<Param>,
107    /// The return type.
108    pub return_ty: TypeExpr,
109    /// The function body.
110    pub body: Block,
111    /// Span covering the entire declaration.
112    pub span: Span,
113}
114
115/// A function parameter: `name: Type`
116#[derive(Debug, Clone, PartialEq)]
117pub struct Param {
118    /// The parameter name.
119    pub name: Ident,
120    /// The parameter type.
121    pub ty: TypeExpr,
122    /// Span covering the parameter.
123    pub span: Span,
124}
125
126// =============================================================================
127// Blocks and statements
128// =============================================================================
129
130/// A block of statements: `{ stmt* }`
131#[derive(Debug, Clone, PartialEq)]
132pub struct Block {
133    /// The statements in this block.
134    pub stmts: Vec<Stmt>,
135    /// Span covering the entire block (including braces).
136    pub span: Span,
137}
138
139/// A statement.
140#[derive(Debug, Clone, PartialEq)]
141pub enum Stmt {
142    /// Variable binding: `let name: Type = expr` or `let name = expr`
143    Let {
144        /// The variable name.
145        name: Ident,
146        /// Optional type annotation.
147        ty: Option<TypeExpr>,
148        /// The initial value.
149        value: Expr,
150        /// Span covering the statement.
151        span: Span,
152    },
153
154    /// Assignment: `name = expr`
155    Assign {
156        /// The variable being assigned to.
157        name: Ident,
158        /// The new value.
159        value: Expr,
160        /// Span covering the statement.
161        span: Span,
162    },
163
164    /// Return statement: `return expr?`
165    Return {
166        /// The optional return value.
167        value: Option<Expr>,
168        /// Span covering the statement.
169        span: Span,
170    },
171
172    /// If statement: `if cond { ... } else { ... }`
173    If {
174        /// The condition (must be Bool).
175        condition: Expr,
176        /// The then branch.
177        then_block: Block,
178        /// The optional else branch (can be another If for else-if chains).
179        else_block: Option<ElseBranch>,
180        /// Span covering the statement.
181        span: Span,
182    },
183
184    /// For loop: `for x in iter { ... }`
185    For {
186        /// The loop variable.
187        var: Ident,
188        /// The iterable expression (must be List<T>).
189        iter: Expr,
190        /// The loop body.
191        body: Block,
192        /// Span covering the statement.
193        span: Span,
194    },
195
196    /// While loop: `while cond { ... }`
197    While {
198        /// The condition (must be Bool).
199        condition: Expr,
200        /// The loop body.
201        body: Block,
202        /// Span covering the statement.
203        span: Span,
204    },
205
206    /// Expression statement: `expr`
207    Expr {
208        /// The expression.
209        expr: Expr,
210        /// Span covering the statement.
211        span: Span,
212    },
213}
214
215impl Stmt {
216    /// Get the span of this statement.
217    #[must_use]
218    pub fn span(&self) -> &Span {
219        match self {
220            Stmt::Let { span, .. }
221            | Stmt::Assign { span, .. }
222            | Stmt::Return { span, .. }
223            | Stmt::If { span, .. }
224            | Stmt::For { span, .. }
225            | Stmt::While { span, .. }
226            | Stmt::Expr { span, .. } => span,
227        }
228    }
229}
230
231/// The else branch of an if statement.
232#[derive(Debug, Clone, PartialEq)]
233pub enum ElseBranch {
234    /// `else { ... }`
235    Block(Block),
236    /// `else if ...` (chained if)
237    ElseIf(Box<Stmt>),
238}
239
240// =============================================================================
241// Expressions
242// =============================================================================
243
244/// An expression.
245#[derive(Debug, Clone, PartialEq)]
246pub enum Expr {
247    /// LLM inference: `infer("template")` or `infer("template" -> Type)`
248    Infer {
249        /// The prompt template (may contain `{ident}` interpolations).
250        template: StringTemplate,
251        /// Optional result type annotation.
252        result_ty: Option<TypeExpr>,
253        /// Span covering the expression.
254        span: Span,
255    },
256
257    /// Agent spawning: `spawn AgentName { field: value, ... }`
258    Spawn {
259        /// The agent type to spawn.
260        agent: Ident,
261        /// Initial belief values.
262        fields: Vec<FieldInit>,
263        /// Span covering the expression.
264        span: Span,
265    },
266
267    /// Await: `await expr`
268    Await {
269        /// The agent handle to await.
270        handle: Box<Expr>,
271        /// Span covering the expression.
272        span: Span,
273    },
274
275    /// Send message: `send(handle, message)`
276    Send {
277        /// The agent handle to send to.
278        handle: Box<Expr>,
279        /// The message to send.
280        message: Box<Expr>,
281        /// Span covering the expression.
282        span: Span,
283    },
284
285    /// Emit value: `emit(value)`
286    Emit {
287        /// The value to emit to the awaiter.
288        value: Box<Expr>,
289        /// Span covering the expression.
290        span: Span,
291    },
292
293    /// Function call: `name(args)`
294    Call {
295        /// The function name.
296        name: Ident,
297        /// The arguments.
298        args: Vec<Expr>,
299        /// Span covering the expression.
300        span: Span,
301    },
302
303    /// Method call on self: `self.method(args)`
304    SelfMethodCall {
305        /// The method name.
306        method: Ident,
307        /// The arguments.
308        args: Vec<Expr>,
309        /// Span covering the expression.
310        span: Span,
311    },
312
313    /// Self field access: `self.field`
314    SelfField {
315        /// The field (belief) name.
316        field: Ident,
317        /// Span covering the expression.
318        span: Span,
319    },
320
321    /// Binary operation: `left op right`
322    Binary {
323        /// The operator.
324        op: BinOp,
325        /// The left operand.
326        left: Box<Expr>,
327        /// The right operand.
328        right: Box<Expr>,
329        /// Span covering the expression.
330        span: Span,
331    },
332
333    /// Unary operation: `op operand`
334    Unary {
335        /// The operator.
336        op: UnaryOp,
337        /// The operand.
338        operand: Box<Expr>,
339        /// Span covering the expression.
340        span: Span,
341    },
342
343    /// List literal: `[a, b, c]`
344    List {
345        /// The list elements.
346        elements: Vec<Expr>,
347        /// Span covering the expression.
348        span: Span,
349    },
350
351    /// Literal value.
352    Literal {
353        /// The literal value.
354        value: Literal,
355        /// Span covering the expression.
356        span: Span,
357    },
358
359    /// Variable reference.
360    Var {
361        /// The variable name.
362        name: Ident,
363        /// Span covering the expression.
364        span: Span,
365    },
366
367    /// Parenthesized expression: `(expr)`
368    Paren {
369        /// The inner expression.
370        inner: Box<Expr>,
371        /// Span covering the expression (including parens).
372        span: Span,
373    },
374
375    /// Interpolated string: `"Hello, {name}!"`
376    StringInterp {
377        /// The string template with interpolations.
378        template: StringTemplate,
379        /// Span covering the expression.
380        span: Span,
381    },
382}
383
384impl Expr {
385    /// Get the span of this expression.
386    #[must_use]
387    pub fn span(&self) -> &Span {
388        match self {
389            Expr::Infer { span, .. }
390            | Expr::Spawn { span, .. }
391            | Expr::Await { span, .. }
392            | Expr::Send { span, .. }
393            | Expr::Emit { span, .. }
394            | Expr::Call { span, .. }
395            | Expr::SelfMethodCall { span, .. }
396            | Expr::SelfField { span, .. }
397            | Expr::Binary { span, .. }
398            | Expr::Unary { span, .. }
399            | Expr::List { span, .. }
400            | Expr::Literal { span, .. }
401            | Expr::Var { span, .. }
402            | Expr::Paren { span, .. }
403            | Expr::StringInterp { span, .. } => span,
404        }
405    }
406}
407
408/// A field initialization in a spawn expression: `field: value`
409#[derive(Debug, Clone, PartialEq)]
410pub struct FieldInit {
411    /// The field (belief) name.
412    pub name: Ident,
413    /// The initial value.
414    pub value: Expr,
415    /// Span covering the field initialization.
416    pub span: Span,
417}
418
// =============================================================================
// Operators
// =============================================================================

/// Infix (two-operand) operators.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    // --- arithmetic ---
    /// Addition, `+`.
    Add,
    /// Subtraction, `-`.
    Sub,
    /// Multiplication, `*`.
    Mul,
    /// Division, `/`.
    Div,

    // --- comparison ---
    /// Equality, `==`.
    Eq,
    /// Inequality, `!=`.
    Ne,
    /// Less than, `<`.
    Lt,
    /// Greater than, `>`.
    Gt,
    /// Less than or equal, `<=`.
    Le,
    /// Greater than or equal, `>=`.
    Ge,

    // --- logical ---
    /// Logical and, `&&`.
    And,
    /// Logical or, `||`.
    Or,

    // --- string ---
    /// String concatenation, `++`.
    Concat,
}

impl BinOp {
    /// Binding strength of the operator; a larger number binds tighter.
    ///
    /// From tightest to loosest: `* /`, `+ -`, `++`, `< > <= >=`,
    /// `== !=`, `&&`, `||`.
    #[must_use]
    pub fn precedence(self) -> u8 {
        match self {
            Self::Mul | Self::Div => 7,
            Self::Add | Self::Sub => 6,
            Self::Concat => 5,
            Self::Lt | Self::Gt | Self::Le | Self::Ge => 4,
            Self::Eq | Self::Ne => 3,
            Self::And => 2,
            Self::Or => 1,
        }
    }

    /// Whether operands of equal precedence group from the left.
    ///
    /// Every binary operator is left-associative, so this is `true` for all
    /// variants.
    #[must_use]
    pub fn is_left_assoc(self) -> bool {
        true
    }
}

impl fmt::Display for BinOp {
    /// Writes the operator's source spelling, e.g. `<=` or `++`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::Eq => "==",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Gt => ">",
            Self::Le => "<=",
            Self::Ge => ">=",
            Self::And => "&&",
            Self::Or => "||",
            Self::Concat => "++",
        };
        f.write_str(symbol)
    }
}

/// Prefix (single-operand) operators.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation, `-`.
    Neg,
    /// Logical not, `!`.
    Not,
}

impl fmt::Display for UnaryOp {
    /// Writes the operator's source spelling.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Neg => "-",
            Self::Not => "!",
        })
    }
}

// =============================================================================
// Literals
// =============================================================================

/// A literal value.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// Integer literal: `42`, `-7`
    Int(i64),
    /// Float literal: `3.14`, `-0.5`
    Float(f64),
    /// Boolean literal: `true`, `false`
    Bool(bool),
    /// String literal: `"hello"`
    String(String),
}

impl fmt::Display for Literal {
    /// Renders the literal in source-like form.
    ///
    /// Floats with an integral value keep a trailing `.0` (`1.0`, not `1`) so
    /// the output cannot be confused with a [`Literal::Int`]. Rust's own `f64`
    /// `Display` drops the fractional part for such values. Non-finite floats
    /// (`inf`, `NaN`) fall through to the standard `f64` rendering.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Literal::Int(n) => write!(f, "{n}"),
            // `{:.1}` prints an integral f64 exactly, followed by `.0`.
            Literal::Float(n) if n.is_finite() && n.fract() == 0.0 => write!(f, "{n:.1}"),
            Literal::Float(n) => write!(f, "{n}"),
            Literal::Bool(b) => write!(f, "{b}"),
            // NOTE(review): the contents are not escaped, so a string containing
            // `"` or `\` renders ambiguously; confirm the lexer's escape syntax
            // before escaping here.
            Literal::String(s) => write!(f, "\"{s}\""),
        }
    }
}

550// =============================================================================
551// String templates (for interpolation)
552// =============================================================================
553
554/// A string template that may contain interpolations.
555///
556/// For example, `"Hello, {name}!"` becomes:
557/// ```text
558/// StringTemplate {
559///     parts: [
560///         StringPart::Literal("Hello, "),
561///         StringPart::Interpolation(Ident("name")),
562///         StringPart::Literal("!"),
563///     ]
564/// }
565/// ```
566#[derive(Debug, Clone, PartialEq)]
567pub struct StringTemplate {
568    /// The parts of the template.
569    pub parts: Vec<StringPart>,
570    /// Span covering the entire template string.
571    pub span: Span,
572}
573
574impl StringTemplate {
575    /// Create a simple template with no interpolations.
576    #[must_use]
577    pub fn literal(s: String, span: Span) -> Self {
578        Self {
579            parts: vec![StringPart::Literal(s)],
580            span,
581        }
582    }
583
584    /// Check if this template has any interpolations.
585    #[must_use]
586    pub fn has_interpolations(&self) -> bool {
587        self.parts
588            .iter()
589            .any(|p| matches!(p, StringPart::Interpolation(_)))
590    }
591
592    /// Get all interpolated identifiers.
593    pub fn interpolations(&self) -> impl Iterator<Item = &Ident> {
594        self.parts.iter().filter_map(|p| match p {
595            StringPart::Interpolation(ident) => Some(ident),
596            StringPart::Literal(_) => None,
597        })
598    }
599}
600
601impl fmt::Display for StringTemplate {
602    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
603        write!(f, "\"")?;
604        for part in &self.parts {
605            match part {
606                StringPart::Literal(s) => write!(f, "{s}")?,
607                StringPart::Interpolation(ident) => write!(f, "{{{ident}}}")?,
608            }
609        }
610        write!(f, "\"")
611    }
612}
613
614/// A part of a string template.
615#[derive(Debug, Clone, PartialEq)]
616pub enum StringPart {
617    /// A literal string segment.
618    Literal(String),
619    /// An interpolated identifier: `{ident}`
620    Interpolation(Ident),
621}
622
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn binop_precedence() {
        // Mul/Div > Add/Sub > Concat > ordering comparisons > equality > And > Or
        assert!(BinOp::Mul.precedence() > BinOp::Add.precedence());
        assert!(BinOp::Add.precedence() > BinOp::Concat.precedence());
        assert!(BinOp::Concat.precedence() > BinOp::Lt.precedence());
        assert!(BinOp::Add.precedence() > BinOp::Lt.precedence());
        assert!(BinOp::Lt.precedence() > BinOp::Eq.precedence());
        assert!(BinOp::Eq.precedence() > BinOp::And.precedence());
        assert!(BinOp::Lt.precedence() > BinOp::And.precedence());
        assert!(BinOp::And.precedence() > BinOp::Or.precedence());
    }

    #[test]
    fn binop_left_assoc() {
        assert!(BinOp::Sub.is_left_assoc());
        assert!(BinOp::Div.is_left_assoc());
        assert!(BinOp::Concat.is_left_assoc());
    }

    #[test]
    fn binop_display() {
        assert_eq!(format!("{}", BinOp::Add), "+");
        assert_eq!(format!("{}", BinOp::Eq), "==");
        assert_eq!(format!("{}", BinOp::Concat), "++");
        assert_eq!(format!("{}", BinOp::And), "&&");
    }

    #[test]
    fn unaryop_display() {
        assert_eq!(format!("{}", UnaryOp::Neg), "-");
        assert_eq!(format!("{}", UnaryOp::Not), "!");
    }

    #[test]
    fn literal_display() {
        assert_eq!(format!("{}", Literal::Int(42)), "42");
        // 2.5 rather than 3.14: the latter trips clippy::approx_constant
        // (deny-by-default), which fails `cargo clippy --all-targets`.
        assert_eq!(format!("{}", Literal::Float(2.5)), "2.5");
        assert_eq!(format!("{}", Literal::Bool(true)), "true");
        assert_eq!(format!("{}", Literal::String("hello".into())), "\"hello\"");
    }

    #[test]
    fn event_kind_display() {
        assert_eq!(format!("{}", EventKind::Start), "start");
        assert_eq!(format!("{}", EventKind::Stop), "stop");

        let msg = EventKind::Message {
            param_name: Ident::dummy("msg"),
            param_ty: TypeExpr::String,
        };
        assert_eq!(format!("{msg}"), "message(msg: String)");
    }

    #[test]
    fn string_template_literal() {
        let template = StringTemplate::literal("hello".into(), Span::dummy());
        assert!(!template.has_interpolations());
        assert_eq!(format!("{template}"), "\"hello\"");
    }

    #[test]
    fn string_template_with_interpolation() {
        let template = StringTemplate {
            parts: vec![
                StringPart::Literal("Hello, ".into()),
                StringPart::Interpolation(Ident::dummy("name")),
                StringPart::Literal("!".into()),
            ],
            span: Span::dummy(),
        };
        assert!(template.has_interpolations());
        assert_eq!(format!("{template}"), "\"Hello, {name}!\"");

        let interps: Vec<_> = template.interpolations().collect();
        assert_eq!(interps.len(), 1);
        assert_eq!(interps[0].name, "name");
    }

    #[test]
    fn expr_span() {
        let span = Span::dummy();
        let expr = Expr::Literal {
            value: Literal::Int(42),
            span: span.clone(),
        };
        assert_eq!(expr.span(), &span);
    }

    #[test]
    fn stmt_span() {
        let span = Span::dummy();
        let stmt = Stmt::Return {
            value: None,
            span: span.clone(),
        };
        assert_eq!(stmt.span(), &span);
    }
}