Skip to main content

trident/ast/
mod.rs

1pub mod display;
2pub mod navigate;
3
4use crate::span::Spanned;
5
/// A parsed `.tri` file — either a program or a library module.
#[derive(Clone, Debug)]
pub struct File {
    /// Whether this file is a program or a library module.
    pub kind: FileKind,
    /// Name of the program or module, with its source span.
    pub name: Spanned<String>,
    /// Module paths this file imports (presumably one per `use` line — confirm against the parser).
    pub uses: Vec<Spanned<ModulePath>>,
    /// Program I/O declarations (see [`Declaration`]).
    /// NOTE(review): presumably always empty for `FileKind::Module` — confirm.
    pub declarations: Vec<Declaration>,
    /// Top-level items: constants, structs, events and functions (see [`Item`]).
    pub items: Vec<Spanned<Item>>,
}
15
/// Program I/O declarations.
#[derive(Clone, Debug)]
pub enum Declaration {
    /// Public input of the given type.
    PubInput(Spanned<Type>),
    /// Public output of the given type.
    PubOutput(Spanned<Type>),
    /// Secret input of the given type (presumably prover-supplied, like `SecRam` — confirm).
    SecInput(Spanned<Type>),
    /// `sec ram: { addr: Type, addr: Type, ... }`
    /// Pre-initialized RAM slots (prover-supplied secret data).
    /// Each entry pairs a RAM address with the type stored at that address.
    SecRam(Vec<(u64, Spanned<Type>)>),
}
26
/// Which kind of `.tri` file was parsed (stored in `File::kind`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FileKind {
    /// A program file.
    Program,
    /// A library module.
    Module,
}
32
/// A dotted module path such as `std.hash`, stored as its segments `["std", "hash"]`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModulePath(pub Vec<String>);

impl ModulePath {
    /// Builds a path made of exactly one segment.
    pub fn single(name: String) -> Self {
        let segments = vec![name];
        ModulePath(segments)
    }

    /// Renders the path with `.` between segments, e.g. `std.hash`.
    /// An empty path renders as the empty string.
    pub fn as_dotted(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for ModulePath {
    /// Writes the segments separated by `.`, without any intermediate allocation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut segments = self.0.iter();
        if let Some(first) = segments.next() {
            f.write_str(first)?;
            for segment in segments {
                f.write_str(".")?;
                f.write_str(segment)?;
            }
        }
        Ok(())
    }
}
52
/// Top-level items in a module.
#[derive(Clone, Debug)]
pub enum Item {
    /// A constant definition.
    Const(ConstDef),
    /// A struct type definition.
    Struct(StructDef),
    /// An event definition (presumably the shape named by `reveal`/`seal` statements — confirm).
    Event(EventDef),
    /// A function definition.
    Fn(FnDef),
}
61
/// A top-level constant definition.
#[derive(Clone, Debug)]
pub struct ConstDef {
    /// Whether the constant is declared public.
    pub is_pub: bool,
    /// Optional configuration condition gating this item
    /// (presumably from a `#[cfg(...)]` attribute — confirm against the parser).
    pub cfg: Option<Spanned<String>>,
    /// Constant name.
    pub name: Spanned<String>,
    /// Declared type.
    pub ty: Spanned<Type>,
    /// Initializer expression.
    pub value: Spanned<Expr>,
}
70
/// A struct type definition.
#[derive(Clone, Debug)]
pub struct StructDef {
    /// Whether the struct is declared public.
    pub is_pub: bool,
    /// Optional configuration condition gating this item
    /// (presumably from a `#[cfg(...)]` attribute — confirm).
    pub cfg: Option<Spanned<String>>,
    /// Struct name.
    pub name: Spanned<String>,
    /// Declared fields.
    pub fields: Vec<StructField>,
}

/// One field of a [`StructDef`].
#[derive(Clone, Debug)]
pub struct StructField {
    /// Whether the field is declared public.
    pub is_pub: bool,
    /// Field name.
    pub name: Spanned<String>,
    /// Field type.
    pub ty: Spanned<Type>,
}
85
/// An event definition: a named set of typed fields.
#[derive(Clone, Debug)]
pub struct EventDef {
    /// Optional configuration condition gating this item
    /// (presumably from a `#[cfg(...)]` attribute — confirm).
    pub cfg: Option<Spanned<String>>,
    /// Event name.
    pub name: Spanned<String>,
    /// Declared fields.
    pub fields: Vec<EventField>,
}

/// One field of an [`EventDef`].
#[derive(Clone, Debug)]
pub struct EventField {
    /// Field name.
    pub name: Spanned<String>,
    /// Field type.
    pub ty: Spanned<Type>,
}
98
/// A function definition together with its attributes.
#[derive(Clone, Debug)]
pub struct FnDef {
    /// Whether the function is declared public.
    pub is_pub: bool,
    /// Optional configuration condition gating this item
    /// (presumably from a `#[cfg(...)]` attribute — confirm).
    pub cfg: Option<Spanned<String>>,
    /// Name of the intrinsic this function is bound to, if any
    /// (presumably set by an attribute — confirm against the parser).
    pub intrinsic: Option<Spanned<String>>,
    /// Whether the function is marked as a test (presumably `#[test]` — confirm).
    pub is_test: bool,
    /// Pure annotation: `#[pure]` — no I/O side effects allowed.
    pub is_pure: bool,
    /// Precondition annotations: `#[requires(predicate)]`.
    /// Each predicate is kept as unparsed source text.
    pub requires: Vec<Spanned<String>>,
    /// Postcondition annotations: `#[ensures(predicate)]`.
    /// Each predicate is kept as unparsed source text.
    pub ensures: Vec<Spanned<String>>,
    /// Function name.
    pub name: Spanned<String>,
    /// Size-generic parameters, e.g. `<N>` in `fn sum<N>(arr: [Field; N])`.
    pub type_params: Vec<Spanned<String>>,
    /// Value parameters.
    pub params: Vec<Param>,
    /// Declared return type; `None` when no return type is written.
    pub return_ty: Option<Spanned<Type>>,
    /// Function body; `None` when the function has no body
    /// (NOTE(review): presumably intrinsic/declaration-only functions — confirm).
    pub body: Option<Spanned<Block>>,
}

/// A function parameter: name plus declared type.
#[derive(Clone, Debug)]
pub struct Param {
    /// Parameter name.
    pub name: Spanned<String>,
    /// Parameter type.
    pub ty: Spanned<Type>,
}
124
/// Array size: a compile-time expression over literals and generic size parameters.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ArraySize {
    /// A concrete size, e.g. the `3` in `[Field; 3]`.
    Literal(u64),
    /// A size-generic parameter, e.g. `N` in `fn sum<N>(arr: [Field; N])`.
    Param(String),
    /// Compile-time addition: `M + N` or `N + 1`.
    Add(Box<ArraySize>, Box<ArraySize>),
    /// Compile-time multiplication: `M * N` or `N * 2`.
    Mul(Box<ArraySize>, Box<ArraySize>),
}

impl ArraySize {
    /// Return the concrete size, or `None` for unresolved params/expressions.
    ///
    /// Any [`ArraySize::Param`] anywhere in the expression makes the result `None`.
    /// Arithmetic saturates at `u64::MAX` rather than overflowing.
    pub fn as_literal(&self) -> Option<u64> {
        match self {
            ArraySize::Literal(n) => Some(*n),
            ArraySize::Add(a, b) => Some(a.as_literal()?.saturating_add(b.as_literal()?)),
            ArraySize::Mul(a, b) => Some(a.as_literal()?.saturating_mul(b.as_literal()?)),
            ArraySize::Param(_) => None,
        }
    }

    /// Evaluate with substitutions for size parameters.
    ///
    /// `subs` maps parameter names to concrete sizes. A parameter that has no entry in
    /// `subs` silently evaluates to `0`; callers that must detect unbound parameters
    /// should check `subs` (or use [`ArraySize::as_literal`]) first.
    /// Arithmetic saturates at `u64::MAX` rather than overflowing.
    pub fn eval(&self, subs: &std::collections::BTreeMap<String, u64>) -> u64 {
        match self {
            ArraySize::Literal(n) => *n,
            ArraySize::Param(name) => subs.get(name).copied().unwrap_or(0),
            ArraySize::Add(a, b) => a.eval(subs).saturating_add(b.eval(subs)),
            ArraySize::Mul(a, b) => a.eval(subs).saturating_mul(b.eval(subs)),
        }
    }
}

impl std::fmt::Display for ArraySize {
    /// Renders the expression in source syntax, e.g. `(N + 1) * 2`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes one operand of `*`, parenthesizing an addition so the printed text keeps
        // the same grouping (`*` binds tighter than `+`).
        fn write_mul_operand(
            f: &mut std::fmt::Formatter<'_>,
            operand: &ArraySize,
        ) -> std::fmt::Result {
            if matches!(operand, ArraySize::Add(..)) {
                write!(f, "({})", operand)
            } else {
                write!(f, "{}", operand)
            }
        }

        match self {
            ArraySize::Literal(n) => write!(f, "{}", n),
            ArraySize::Param(name) => write!(f, "{}", name),
            ArraySize::Add(a, b) => write!(f, "{} + {}", a, b),
            ArraySize::Mul(a, b) => {
                // Stream straight into the formatter instead of building temporary Strings.
                write_mul_operand(f, a)?;
                f.write_str(" * ")?;
                write_mul_operand(f, b)
            }
        }
    }
}
181
/// Syntactic types (as written in source).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Type {
    /// The `Field` type.
    Field,
    /// The `XField` type (NOTE(review): presumably an extension-field element — confirm).
    XField,
    /// Boolean type.
    Bool,
    /// 32-bit unsigned integer type.
    U32,
    /// Digest type (presumably a hash output — confirm).
    Digest,
    /// Fixed-size array `[T; N]`: element type and size expression.
    Array(Box<Type>, ArraySize),
    /// Tuple of element types.
    Tuple(Vec<Type>),
    /// A type referred to by (possibly dotted) path, e.g. a struct name.
    Named(ModulePath),
}
194
/// A block of statements with an optional trailing expression.
#[derive(Clone, Debug)]
pub struct Block {
    /// Statements of the block.
    pub stmts: Vec<Spanned<Stmt>>,
    /// Optional final expression (presumably the block's value, as in Rust — confirm).
    pub tail_expr: Option<Box<Spanned<Expr>>>,
}
201
/// A binding pattern for `let` statements.
///
/// Only plain names and flat tuples of names are supported; nested patterns are not.
#[derive(Clone, Debug)]
pub enum Pattern {
    /// Single name: `let x = ...`
    Name(Spanned<String>),
    /// Tuple destructure: `let (a, b) = ...`
    Tuple(Vec<Spanned<String>>),
}
210
/// A pattern in a match arm.
#[derive(Clone, Debug)]
pub enum MatchPattern {
    /// Integer or boolean literal: `0`, `42`, `true`, `false`.
    Literal(Literal),
    /// Wildcard: `_`.
    Wildcard,
    /// Struct destructuring: `Point { x, y }` or `Point { x: a, y: 0 }`.
    /// Each field maps to a `StructPatternField`.
    Struct {
        /// Name of the struct type being destructured.
        name: Spanned<String>,
        /// Per-field sub-patterns.
        fields: Vec<StructPatternField>,
    },
}
225
/// A field in a struct destructuring pattern.
#[derive(Clone, Debug)]
pub struct StructPatternField {
    /// The struct field name being matched.
    pub field_name: Spanned<String>,
    /// The pattern for this field: a binding name, a literal, or wildcard.
    pub pattern: Spanned<FieldPattern>,
}

/// What a struct pattern field matches against.
#[derive(Clone, Debug)]
pub enum FieldPattern {
    /// Bind to a variable: `x` (shorthand) or `x: var_name`.
    /// The bound name has no span of its own; the enclosing `Spanned` covers it.
    Binding(String),
    /// Match a literal value: `x: 0` or `x: true`.
    Literal(Literal),
    /// Wildcard: `x: _`.
    Wildcard,
}
245
/// A single arm in a match statement.
#[derive(Clone, Debug)]
pub struct MatchArm {
    /// Pattern this arm matches.
    pub pattern: Spanned<MatchPattern>,
    /// Block executed when the pattern matches.
    pub body: Spanned<Block>,
}
252
/// Statements.
#[derive(Clone, Debug)]
pub enum Stmt {
    /// `let` binding with optional mutability and optional type annotation.
    Let {
        /// Whether the binding is mutable (presumably `let mut` — confirm).
        mutable: bool,
        /// Name or tuple of names being bound.
        pattern: Pattern,
        /// Optional explicit type annotation.
        ty: Option<Spanned<Type>>,
        /// Initializer expression.
        init: Spanned<Expr>,
    },
    /// Assignment of `value` to an l-value [`Place`].
    Assign {
        place: Spanned<Place>,
        value: Spanned<Expr>,
    },
    /// Assignment of one value to several names at once
    /// (presumably `(a, b) = expr` — confirm against the parser).
    TupleAssign {
        names: Vec<Spanned<String>>,
        value: Spanned<Expr>,
    },
    /// Conditional with an optional `else` block.
    If {
        cond: Spanned<Expr>,
        then_block: Spanned<Block>,
        else_block: Option<Spanned<Block>>,
    },
    /// Loop binding `var` over the range given by `start` and `end`.
    For {
        var: Spanned<String>,
        start: Spanned<Expr>,
        end: Spanned<Expr>,
        /// Optional explicit loop bound.
        /// NOTE(review): presumably a cap on the iteration count — confirm its semantics.
        bound: Option<u64>,
        body: Spanned<Block>,
    },
    /// Expression evaluated as a statement.
    Expr(Spanned<Expr>),
    /// Return, with an optional value.
    Return(Option<Spanned<Expr>>),
    /// `reveal` of the named event with the given field values
    /// (presumably a public event emission — confirm).
    Reveal {
        event_name: Spanned<String>,
        fields: Vec<(Spanned<String>, Spanned<Expr>)>,
    },
    /// `seal` of the named event with the given field values
    /// (presumably a hidden/committed counterpart of `Reveal` — confirm).
    Seal {
        event_name: Spanned<String>,
        fields: Vec<(Spanned<String>, Spanned<Expr>)>,
    },
    /// Inline assembly block.
    Asm {
        /// Raw assembly text.
        body: String,
        /// Declared effect of the block (presumably the net stack-height change — confirm).
        effect: i32,
        /// Optional target the block is restricted to (presumably a backend name — confirm).
        target: Option<String>,
    },
    /// `match` on `expr` with its arms.
    Match {
        expr: Spanned<Expr>,
        arms: Vec<MatchArm>,
    },
}
302
/// L-value places (can appear on left side of assignment).
#[derive(Clone, Debug)]
pub enum Place {
    /// A variable, by name.
    Var(String),
    /// A named field of another place (`place.field`).
    FieldAccess(Box<Spanned<Place>>, Spanned<String>),
    /// An indexed element of another place (`place[index]`).
    Index(Box<Spanned<Place>>, Box<Spanned<Expr>>),
}
310
/// Expressions.
#[derive(Clone, Debug)]
pub enum Expr {
    /// Integer or boolean literal.
    Literal(Literal),
    /// Variable reference, by name.
    Var(String),
    /// Binary operation `lhs op rhs`.
    BinOp {
        op: BinOp,
        lhs: Box<Spanned<Expr>>,
        rhs: Box<Spanned<Expr>>,
    },
    /// Function call through a (possibly dotted) path.
    Call {
        path: Spanned<ModulePath>,
        /// Explicit size-generic arguments, e.g. `sum<3>(...)`.
        generic_args: Vec<Spanned<ArraySize>>,
        /// Value arguments.
        args: Vec<Spanned<Expr>>,
    },
    /// Field access `expr.field`.
    FieldAccess {
        expr: Box<Spanned<Expr>>,
        field: Spanned<String>,
    },
    /// Indexing `expr[index]`.
    Index {
        expr: Box<Spanned<Expr>>,
        index: Box<Spanned<Expr>>,
    },
    /// Struct literal: type path plus `(field name, value)` pairs.
    StructInit {
        path: Spanned<ModulePath>,
        fields: Vec<(Spanned<String>, Spanned<Expr>)>,
    },
    /// Array literal built from element expressions.
    ArrayInit(Vec<Spanned<Expr>>),
    /// Tuple expression.
    Tuple(Vec<Spanned<Expr>>),
}
342
/// Literal values usable in expressions and patterns.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Literal {
    /// Unsigned integer literal.
    Integer(u64),
    /// `true` or `false`.
    Bool(bool),
}
348
/// Binary operators usable in expressions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinOp {
    /// `+`
    Add,
    /// `*`
    Mul,
    /// `==`
    Eq,
    /// `<`
    Lt,
    /// `&`
    BitAnd,
    /// `^`
    BitXor,
    /// `/%`
    DivMod,
    /// `*.`
    XFieldMul,
}

impl BinOp {
    /// The operator's spelling in source code, e.g. `"/%"` for [`BinOp::DivMod`].
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Add => "+",
            Self::Mul => "*",
            Self::Eq => "==",
            Self::Lt => "<",
            Self::BitAnd => "&",
            Self::BitXor => "^",
            Self::DivMod => "/%",
            Self::XFieldMul => "*.",
        }
    }

    /// Pratt binding power as `(left, right)`; a higher number binds tighter.
    ///
    /// This is the single precedence table shared by the parser and the formatter.
    /// From loosest to tightest: `==`, `<`, `+`, `*`/`*.`, `&`/`^`, `/%`.
    pub fn binding_power(&self) -> (u8, u8) {
        // Each tier's right power sits one above its left power, which makes every
        // operator left-associative in a Pratt parser.
        let left: u8 = match *self {
            Self::Eq => 2,
            Self::Lt => 4,
            Self::Add => 6,
            Self::Mul | Self::XFieldMul => 8,
            Self::BitAnd | Self::BitXor => 10,
            Self::DivMod => 12,
        };
        (left, left + 1)
    }
}