Skip to main content

sim_kernel/
expr.rs

1//! The [`Expr`] graph: the checked-form representation of source.
2//!
3//! The kernel defines the expression graph, number literals, quote modes, and
4//! source-origin tracking; general-purpose codecs in library crates round-trip
5//! every expression through this shared graph.
6
7use std::{
8    collections::BTreeMap,
9    hash::{Hash, Hasher},
10    sync::Arc,
11};
12
13use crate::id::{CodecId, Symbol};
14
15/// The codec-neutral expression graph: the checked form of source.
16///
17/// The kernel defines this graph; general-purpose codecs in library crates
18/// round-trip every expression through it. It is the `checked forms` stage of
19/// the data flow `tokens -> checked forms -> objects -> checked calls ->
20/// objects -> encoded forms`. Equality and hashing are canonical (see
21/// [`Expr::canonical_eq`]): map entries and set members compare order-insensitively.
22///
23/// # Examples
24///
25/// ```
26/// # use sim_kernel::expr::Expr;
27/// # use sim_kernel::id::Symbol;
28/// let call = Expr::Call {
29///     operator: Box::new(Expr::Symbol(Symbol::new("add"))),
30///     args: vec![Expr::Bool(true), Expr::Nil],
31/// };
32/// // Maps compare canonically regardless of entry order.
33/// let a = Expr::Set(vec![Expr::Bool(true), Expr::Nil]);
34/// let b = Expr::Set(vec![Expr::Nil, Expr::Bool(true)]);
35/// assert!(a.canonical_eq(&b));
36/// let _ = call;
37/// ```
38#[derive(Clone, Debug)]
39pub enum Expr {
40    /// The nil literal.
41    Nil,
42    /// A boolean literal.
43    Bool(bool),
44    /// A number literal in some domain.
45    Number(NumberLiteral),
46    /// A symbol reference.
47    Symbol(Symbol),
48    /// A lexical local reference.
49    Local(Symbol),
50    /// A string literal.
51    String(String),
52    /// A byte-string literal.
53    Bytes(Vec<u8>),
54    /// An ordered list form.
55    List(Vec<Expr>),
56    /// An ordered vector form.
57    Vector(Vec<Expr>),
58    /// A map form of key/value pairs (order-insensitive when compared).
59    Map(Vec<(Expr, Expr)>),
60    /// A set form (order-insensitive when compared).
61    Set(Vec<Expr>),
62    /// A call of `operator` with positional `args`.
63    Call {
64        /// The expression producing the callable operator.
65        operator: Box<Expr>,
66        /// The positional argument expressions.
67        args: Vec<Expr>,
68    },
69    /// An infix operator application.
70    Infix {
71        /// The operator symbol.
72        operator: Symbol,
73        /// The left operand.
74        left: Box<Expr>,
75        /// The right operand.
76        right: Box<Expr>,
77    },
78    /// A prefix operator application.
79    Prefix {
80        /// The operator symbol.
81        operator: Symbol,
82        /// The operand.
83        arg: Box<Expr>,
84    },
85    /// A postfix operator application.
86    Postfix {
87        /// The operator symbol.
88        operator: Symbol,
89        /// The operand.
90        arg: Box<Expr>,
91    },
92    /// A sequenced block of expressions.
93    Block(Vec<Expr>),
94    /// A quoted expression carried at a given quote mode.
95    Quote {
96        /// The quote mode (quote, quasiquote, ...).
97        mode: QuoteMode,
98        /// The quoted expression.
99        expr: Box<Expr>,
100    },
101    /// An expression carrying named annotations.
102    Annotated {
103        /// The annotated inner expression.
104        expr: Box<Expr>,
105        /// The name/value annotation pairs.
106        annotations: Vec<(Symbol, Expr)>,
107    },
108    /// An open extension form with a tag and payload.
109    Extension {
110        /// The extension tag.
111        tag: Symbol,
112        /// The extension payload expression.
113        payload: Box<Expr>,
114    },
115}
116
117/// A number literal: a domain symbol plus its canonical textual form.
118///
119/// The kernel carries the literal verbatim; concrete number domains and
120/// arithmetic are supplied by libraries.
121#[derive(Clone, Debug, PartialEq, Eq, Hash)]
122pub struct NumberLiteral {
123    /// The number domain naming how `canonical` is interpreted.
124    pub domain: Symbol,
125    /// The canonical textual representation of the number.
126    pub canonical: String,
127}
128
/// How the inner expression of an [`Expr::Quote`] form is quoted.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum QuoteMode {
    /// Ordinary (plain) quotation.
    Quote,
    /// Quasiquotation; nested unquotes are permitted inside it.
    QuasiQuote,
    /// An unquote appearing inside a quasiquote.
    Unquote,
    /// A splicing unquote appearing inside a quasiquote.
    Splice,
    /// Hygienic syntax quotation.
    Syntax,
}
143
144/// An [`Expr`] paired with its optional source [`Origin`].
145#[derive(Clone, Debug, PartialEq, Eq, Hash)]
146pub struct LocatedExpr {
147    /// The expression.
148    pub expr: Expr,
149    /// The optional source origin.
150    pub origin: Option<Origin>,
151}
152
153impl LocatedExpr {
154    /// Compares two located expressions canonically, ignoring origin.
155    pub fn canonical_eq(&self, other: &Self) -> bool {
156        self.expr.canonical_eq(&other.expr)
157    }
158}
159
160/// A located expression together with its located children.
161///
162/// Carries per-node [`Origin`] for a whole expression tree, so codecs can
163/// preserve source spans down to each sub-form.
164#[derive(Clone, Debug, PartialEq, Eq, Hash)]
165pub struct LocatedExprTree {
166    /// The expression at this node.
167    pub expr: Expr,
168    /// The optional source origin of this node.
169    pub origin: Option<Origin>,
170    /// The located child nodes.
171    pub children: Vec<LocatedExprTree>,
172}
173
174impl LocatedExprTree {
175    /// Compares two trees canonically, ignoring origin, including children.
176    pub fn canonical_eq(&self, other: &Self) -> bool {
177        self.expr.canonical_eq(&other.expr)
178            && self.children.len() == other.children.len()
179            && self
180                .children
181                .iter()
182                .zip(other.children.iter())
183                .all(|(left, right)| left.canonical_eq(right))
184    }
185
186    /// Projects this node to a flat [`LocatedExpr`], dropping children.
187    pub fn located(&self) -> LocatedExpr {
188        LocatedExpr {
189            expr: self.expr.clone(),
190            origin: self.origin.clone(),
191        }
192    }
193
194    /// Builds a leaf tree node with no children.
195    pub fn without_children(expr: Expr, origin: Option<Origin>) -> Self {
196        Self {
197            expr,
198            origin,
199            children: Vec::new(),
200        }
201    }
202
203    /// Lifts a flat [`LocatedExpr`] into a leaf tree node.
204    pub fn from_located(located: LocatedExpr) -> Self {
205        Self {
206            expr: located.expr,
207            origin: located.origin,
208            children: Vec::new(),
209        }
210    }
211
212    /// Builds a tree from `expr` by recursively wrapping its children, with no origins.
213    pub fn from_expr_recursive(expr: Expr) -> Self {
214        let children = expr_children(&expr)
215            .into_iter()
216            .map(Self::from_expr_recursive)
217            .collect();
218        Self {
219            expr,
220            origin: None,
221            children,
222        }
223    }
224}
225
226fn expr_children(expr: &Expr) -> Vec<Expr> {
227    match expr {
228        Expr::Nil
229        | Expr::Bool(_)
230        | Expr::Number(_)
231        | Expr::Symbol(_)
232        | Expr::Local(_)
233        | Expr::String(_)
234        | Expr::Bytes(_) => Vec::new(),
235        Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
236            items.clone()
237        }
238        Expr::Map(entries) => entries
239            .iter()
240            .flat_map(|(key, value)| [key.clone(), value.clone()])
241            .collect(),
242        Expr::Call { operator, args } => std::iter::once(operator.as_ref().clone())
243            .chain(args.iter().cloned())
244            .collect(),
245        Expr::Infix { left, right, .. } => vec![left.as_ref().clone(), right.as_ref().clone()],
246        Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => vec![arg.as_ref().clone()],
247        Expr::Quote { expr, .. } => vec![expr.as_ref().clone()],
248        Expr::Annotated { expr, annotations } => std::iter::once(expr.as_ref().clone())
249            .chain(annotations.iter().map(|(_, value)| value.clone()))
250            .collect(),
251        Expr::Extension { payload, .. } => vec![payload.as_ref().clone()],
252    }
253}
254
255/// Source provenance of an expression: which codec, source, span, and trivia.
256#[derive(Clone, Debug, PartialEq, Eq, Hash)]
257pub struct Origin {
258    /// The codec that read the expression.
259    pub codec: CodecId,
260    /// The source the expression was read from.
261    pub source: SourceId,
262    /// The byte span within the source.
263    pub span: Span,
264    /// The surrounding whitespace and comment trivia.
265    pub trivia: Vec<Trivia>,
266}
267
/// An opaque, totally ordered identifier naming a registered source.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct SourceId(pub String);
271
272/// A store mapping each [`SourceId`] to its raw source bytes.
273///
274/// Lets diagnostics and lossless codecs recover original text by [`Origin`].
275#[derive(Clone, Debug, Default)]
276pub struct SourceRegistry {
277    sources: BTreeMap<SourceId, Arc<[u8]>>,
278}
279
280impl SourceRegistry {
281    /// Inserts source bytes, returning any previous bytes for the same id.
282    pub fn insert(&mut self, source: SourceId, bytes: Arc<[u8]>) -> Option<Arc<[u8]>> {
283        self.sources.insert(source, bytes)
284    }
285
286    /// Interns raw bytes under `source`, replacing any prior entry.
287    pub fn intern_bytes(&mut self, source: SourceId, bytes: impl Into<Arc<[u8]>>) {
288        self.sources.insert(source, bytes.into());
289    }
290
291    /// Interns text under `source` as its UTF-8 bytes.
292    pub fn intern_text(&mut self, source: SourceId, text: &str) {
293        self.intern_bytes(source, Arc::<[u8]>::from(text.as_bytes()));
294    }
295
296    /// Splices `bytes` into the stored source at the origin's span.
297    pub fn intern_span(&mut self, origin: &Origin, bytes: &[u8]) {
298        let required_len = origin.span.end.max(bytes.len());
299        let mut merged = self
300            .sources
301            .get(&origin.source)
302            .map(|existing| existing.as_ref().to_vec())
303            .unwrap_or_default();
304        if merged.len() < required_len {
305            merged.resize(required_len, 0);
306        }
307        let end = origin.span.start.saturating_add(bytes.len());
308        if end <= merged.len() {
309            merged[origin.span.start..end].copy_from_slice(bytes);
310            self.sources
311                .insert(origin.source.clone(), Arc::<[u8]>::from(merged));
312        }
313    }
314
315    /// Returns the stored bytes for `source`, if any.
316    pub fn get(&self, source: &SourceId) -> Option<&Arc<[u8]>> {
317        self.sources.get(source)
318    }
319
320    /// Returns the byte slice covered by `origin`'s span, if available.
321    pub fn slice(&self, origin: &Origin) -> Option<&[u8]> {
322        self.sources
323            .get(&origin.source)
324            .and_then(|bytes| bytes.get(origin.span.start..origin.span.end))
325    }
326}
327
/// A byte range within a source, half-open: it covers offsets from `start` up
/// to but not including `end`, i.e. `[start, end)`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Span {
    /// Offset of the first byte in the range (inclusive).
    pub start: usize,
    /// Offset one past the last byte in the range (exclusive).
    pub end: usize,
}
336
/// Non-semantic source text (whitespace and comments) preserved alongside an
/// expression.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum Trivia {
    /// A run of whitespace text.
    Whitespace(String),
    /// A comment spanning a single line.
    LineComment(String),
    /// A block comment.
    BlockComment(String),
}
347
348/// A normalized, comparable key derived from an [`Expr`].
349///
350/// Backs canonical equality and hashing: structurally equal expressions
351/// (modulo map/set ordering) produce equal keys. Each variant tags its shape
352/// with a static label so unlike shapes never collide.
353#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
354pub enum CanonicalKey {
355    /// A tagged atom with no payload.
356    Atom(&'static str),
357    /// A tagged boolean.
358    Bool(&'static str, bool),
359    /// Tagged raw bytes.
360    Bytes(&'static str, Vec<u8>),
361    /// A tagged string.
362    String(&'static str, String),
363    /// A tagged symbol.
364    Symbol(&'static str, Symbol),
365    /// A tagged pair of strings.
366    Pair(&'static str, String, String),
367    /// A tagged ordered sequence of sub-keys.
368    Compound(&'static str, Vec<CanonicalKey>),
369    /// A tagged sequence of name/sub-key pairs.
370    CompoundNamed(&'static str, Vec<(String, CanonicalKey)>),
371}
372
373impl CanonicalKey {
374    fn tag(tag: &'static str) -> Self {
375        Self::Atom(tag)
376    }
377
378    fn with_bool(tag: &'static str, value: bool) -> Self {
379        Self::Bool(tag, value)
380    }
381
382    fn with_bytes(tag: &'static str, value: &[u8]) -> Self {
383        Self::Bytes(tag, value.to_vec())
384    }
385
386    fn with_string(tag: &'static str, value: &str) -> Self {
387        Self::String(tag, value.to_owned())
388    }
389
390    fn with_symbol(tag: &'static str, symbol: &Symbol) -> Self {
391        Self::Symbol(tag, symbol.clone())
392    }
393
394    fn with_pair(tag: &'static str, left: String, right: String) -> Self {
395        Self::Pair(tag, left, right)
396    }
397
398    fn compound(tag: &'static str, items: Vec<CanonicalKey>) -> Self {
399        Self::Compound(tag, items)
400    }
401
402    fn compound_named(tag: &'static str, items: Vec<(String, CanonicalKey)>) -> Self {
403        Self::CompoundNamed(tag, items)
404    }
405}
406
407impl Expr {
408    /// Computes the [`CanonicalKey`] backing canonical equality and hashing.
409    pub fn canonical_key(&self) -> CanonicalKey {
410        match self {
411            Self::Nil => CanonicalKey::tag("nil"),
412            Self::Bool(value) => CanonicalKey::with_bool("bool", *value),
413            Self::Number(value) => {
414                CanonicalKey::with_pair("number", value.domain.to_string(), value.canonical.clone())
415            }
416            Self::Symbol(symbol) => CanonicalKey::with_symbol("symbol", symbol),
417            Self::Local(symbol) => CanonicalKey::with_symbol("local", symbol),
418            Self::String(value) => CanonicalKey::with_string("string", value),
419            Self::Bytes(value) => CanonicalKey::with_bytes("bytes", value),
420            Self::List(items) => {
421                CanonicalKey::compound("list", items.iter().map(Self::canonical_key).collect())
422            }
423            Self::Vector(items) => {
424                CanonicalKey::compound("vector", items.iter().map(Self::canonical_key).collect())
425            }
426            Self::Map(entries) => {
427                let mut items = entries
428                    .iter()
429                    .map(|(key, value)| {
430                        CanonicalKey::compound(
431                            "entry",
432                            vec![key.canonical_key(), value.canonical_key()],
433                        )
434                    })
435                    .collect::<Vec<_>>();
436                items.sort();
437                CanonicalKey::compound("map", items)
438            }
439            Self::Set(items) => {
440                let mut items = items.iter().map(Self::canonical_key).collect::<Vec<_>>();
441                items.sort();
442                CanonicalKey::compound("set", items)
443            }
444            Self::Call { operator, args } => CanonicalKey::compound(
445                "call",
446                std::iter::once(operator.canonical_key())
447                    .chain(args.iter().map(Self::canonical_key))
448                    .collect(),
449            ),
450            Self::Infix {
451                operator,
452                left,
453                right,
454            } => CanonicalKey::compound_named(
455                "infix",
456                vec![
457                    (
458                        "operator".to_owned(),
459                        CanonicalKey::with_symbol("symbol", operator),
460                    ),
461                    ("left".to_owned(), left.canonical_key()),
462                    ("right".to_owned(), right.canonical_key()),
463                ],
464            ),
465            Self::Prefix { operator, arg } => CanonicalKey::compound_named(
466                "prefix",
467                vec![
468                    (
469                        "operator".to_owned(),
470                        CanonicalKey::with_symbol("symbol", operator),
471                    ),
472                    ("arg".to_owned(), arg.canonical_key()),
473                ],
474            ),
475            Self::Postfix { operator, arg } => CanonicalKey::compound_named(
476                "postfix",
477                vec![
478                    (
479                        "operator".to_owned(),
480                        CanonicalKey::with_symbol("symbol", operator),
481                    ),
482                    ("arg".to_owned(), arg.canonical_key()),
483                ],
484            ),
485            Self::Block(items) => {
486                CanonicalKey::compound("block", items.iter().map(Self::canonical_key).collect())
487            }
488            Self::Quote { mode, expr } => CanonicalKey::compound_named(
489                "quote",
490                vec![
491                    (
492                        "mode".to_owned(),
493                        CanonicalKey::with_string("quote-mode", &format!("{mode:?}")),
494                    ),
495                    ("expr".to_owned(), expr.canonical_key()),
496                ],
497            ),
498            Self::Annotated { expr, annotations } => CanonicalKey::compound_named(
499                "annotated",
500                std::iter::once(("expr".to_owned(), expr.canonical_key()))
501                    .chain(
502                        annotations
503                            .iter()
504                            .map(|(symbol, expr)| (symbol.to_string(), expr.canonical_key())),
505                    )
506                    .collect(),
507            ),
508            Self::Extension { tag, payload } => CanonicalKey::compound_named(
509                "extension",
510                vec![
511                    ("tag".to_owned(), CanonicalKey::with_symbol("symbol", tag)),
512                    ("payload".to_owned(), payload.canonical_key()),
513                ],
514            ),
515        }
516    }
517
518    /// Returns whether two expressions are canonically equal.
519    pub fn canonical_eq(&self, other: &Self) -> bool {
520        self.canonical_key() == other.canonical_key()
521    }
522}
523
524impl PartialEq for Expr {
525    fn eq(&self, other: &Self) -> bool {
526        self.canonical_key() == other.canonical_key()
527    }
528}
529
530impl Eq for Expr {}
531
532impl Hash for Expr {
533    fn hash<H: Hasher>(&self, state: &mut H) {
534        self.canonical_key().hash(state);
535    }
536}
537
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use crate::id::Symbol;

    use super::{Expr, NumberLiteral};

    /// Hashes `expr` with a fresh `DefaultHasher`.
    fn digest(expr: &Expr) -> u64 {
        let mut state = DefaultHasher::new();
        expr.hash(&mut state);
        state.finish()
    }

    /// Shorthand for a symbol expression.
    fn sym(name: &'static str) -> Expr {
        Expr::Symbol(Symbol::new(name))
    }

    #[test]
    fn map_and_set_canonical_order_is_normalized() {
        let one = Expr::Number(NumberLiteral {
            domain: Symbol::qualified("numbers", "f64"),
            canonical: String::from("1.0"),
        });

        let forward = Expr::Map(vec![(sym("a"), one.clone()), (sym("b"), Expr::Nil)]);
        let backward = Expr::Map(vec![(sym("b"), Expr::Nil), (sym("a"), one)]);
        assert!(forward.canonical_eq(&backward));

        let forward = Expr::Set(vec![Expr::String(String::from("x")), Expr::Bool(true)]);
        let backward = Expr::Set(vec![Expr::Bool(true), Expr::String(String::from("x"))]);
        assert!(forward.canonical_eq(&backward));
    }

    #[test]
    fn expr_hash_matches_canonical_map_equality() {
        let forward = Expr::Map(vec![(sym("a"), Expr::Bool(true)), (sym("b"), Expr::Nil)]);
        let backward = Expr::Map(vec![(sym("b"), Expr::Nil), (sym("a"), Expr::Bool(true))]);

        assert_eq!(digest(&forward), digest(&backward));
        assert_eq!(forward, backward);
    }

    #[test]
    fn expr_hash_matches_canonical_set_equality() {
        let forward = Expr::Set(vec![Expr::String(String::from("x")), Expr::Bool(true)]);
        let backward = Expr::Set(vec![Expr::Bool(true), Expr::String(String::from("x"))]);

        assert_eq!(digest(&forward), digest(&backward));
        assert_eq!(forward, backward);
    }
}