wac_parser/
ast.rs

1//! Module for the AST implementation.
2
3use crate::{
4    lexer::{self, Lexer, LexerResult, Token},
5    resolution::{AstResolver, Resolution, ResolutionResult},
6};
7use indexmap::IndexMap;
8use miette::{Diagnostic, SourceSpan};
9use serde::Serialize;
10use std::fmt;
11use wac_graph::types::BorrowedPackageKey;
12
13mod export;
14mod expr;
15mod import;
16mod r#let;
17mod printer;
18mod r#type;
19
20pub use export::*;
21pub use expr::*;
22pub use import::*;
23pub use printer::*;
24pub use r#let::*;
25pub use r#type::*;
26
27struct Found(Option<Token>);
28
29impl fmt::Display for Found {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self.0 {
32            Some(t) => t.fmt(f),
33            None => write!(f, "end of input"),
34        }
35    }
36}
37
38struct Expected<'a> {
39    expected: &'a [Option<Token>],
40    count: usize,
41}
42
43impl fmt::Display for Expected<'_> {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        let mut expected = self.expected.iter().enumerate();
46        while let Some((i, Some(token))) = expected.next() {
47            if i > 0 {
48                write!(f, ", ")?;
49            }
50
51            if i == self.count - 1 {
52                write!(f, "or ")?;
53            }
54
55            token.fmt(f)?;
56        }
57
58        if self.count > self.expected.len() {
59            write!(f, ", or more...")?;
60        }
61
62        Ok(())
63    }
64}
65
/// Represents a parse error.
#[derive(thiserror::Error, Diagnostic, Debug)]
#[diagnostic(code("failed to parse document"))]
pub enum Error {
    /// A lexer error occurred.
    #[error("{error}")]
    Lexer {
        /// The lexer error that occurred.
        error: crate::lexer::Error,
        /// The span where the error occurred.
        #[label(primary)]
        span: SourceSpan,
    },
    /// An unexpected token was encountered when a single token was expected.
    #[error("expected {expected}, found {found}", found = Found(*.found))]
    Expected {
        /// The expected token.
        expected: Token,
        /// The found token (`None` for end of input).
        found: Option<Token>,
        /// The span of the found token.
        #[label(primary, "unexpected {found}", found = Found(*.found))]
        span: SourceSpan,
    },
    /// An unexpected token was encountered when either one of two tokens was expected.
    #[error("expected {first} or {second}, found {found}", found = Found(*.found))]
    ExpectedEither {
        /// The first expected token.
        first: Token,
        /// The second expected token.
        second: Token,
        /// The found token.
        found: Option<Token>,
        /// The span of the found token.
        #[label(primary, "unexpected {found}", found = Found(*.found))]
        span: SourceSpan,
    },
    /// An unexpected token was encountered when multiple tokens were expected.
    #[error("expected either {expected}, found {found}", expected = Expected { expected, count: *.count }, found = Found(*.found))]
    ExpectedMultiple {
        /// The tokens that were expected.
        ///
        /// The fixed capacity of 10 mirrors `Lookahead::attempts`; when more
        /// than 10 tokens were attempted, `count` exceeds the number of
        /// recorded entries and the message ends with ", or more...".
        expected: [Option<Token>; 10],
        /// The count of expected tokens.
        count: usize,
        /// The found token.
        found: Option<Token>,
        /// The span of the found token.
        #[label(primary, "unexpected {found}", found = Found(*.found))]
        span: SourceSpan,
    },
    /// An empty type was encountered.
    #[error("{ty} must contain at least one {kind}")]
    EmptyType {
        /// The type that was empty (e.g. "record", "variant", etc.)
        ty: &'static str,
        /// The kind of item that was empty (e.g. "field", "case", etc.)
        kind: &'static str,
        /// The span of the empty type.
        #[label(primary, "empty {ty}")]
        span: SourceSpan,
    },
    /// An invalid semantic version was encountered.
    #[error("`{version}` is not a valid semantic version")]
    InvalidVersion {
        /// The invalid version.
        version: std::string::String,
        /// The span of the version.
        #[label(primary, "invalid version")]
        span: SourceSpan,
    },
}
137
138/// Represents a parse result.
139pub type ParseResult<T> = Result<T, Error>;
140
141impl From<(lexer::Error, SourceSpan)> for Error {
142    fn from((e, s): (lexer::Error, SourceSpan)) -> Self {
143        Self::Lexer { error: e, span: s }
144    }
145}
146
147/// Expects a given token from the lexer.
148pub fn parse_token(lexer: &mut Lexer, expected: Token) -> ParseResult<SourceSpan> {
149    let (result, span) = lexer.next().ok_or_else(|| Error::Expected {
150        expected,
151        found: None,
152        span: lexer.span(),
153    })?;
154
155    match result {
156        Ok(found) if found == expected => Ok(span),
157        Ok(found) => Err(Error::Expected {
158            expected,
159            found: Some(found),
160            span,
161        }),
162        Err(e) => Err((e, span).into()),
163    }
164}
165
166/// Parses an optional tokens from a lexer.
167///
168/// The expected token is removed from the token stream before the callback is invoked.
169pub fn parse_optional<'a, F, R>(
170    lexer: &mut Lexer<'a>,
171    expected: Token,
172    cb: F,
173) -> ParseResult<Option<R>>
174where
175    F: FnOnce(&mut Lexer<'a>) -> ParseResult<R>,
176{
177    match lexer.peek() {
178        Some((Ok(token), _)) => {
179            if token == expected {
180                lexer.next();
181                Ok(Some(cb(lexer)?))
182            } else {
183                Ok(None)
184            }
185        }
186        Some((Err(e), s)) => Err((e, s).into()),
187        None => Ok(None),
188    }
189}
190
191/// Used to look ahead one token in the lexer.
192///
193/// The lookahead stores up to 10 attempted tokens.
194pub struct Lookahead {
195    next: Option<(LexerResult<Token>, SourceSpan)>,
196    attempts: [Option<Token>; 10],
197    span: SourceSpan,
198    count: usize,
199}
200
201impl Lookahead {
202    /// Creates a new lookahead from the given lexer.
203    pub fn new(lexer: &Lexer) -> Self {
204        Self {
205            next: lexer.peek(),
206            span: lexer.span(),
207            attempts: Default::default(),
208            count: 0,
209        }
210    }
211
212    /// Peeks to see if the next token matches the given token.
213    pub fn peek(&mut self, expected: Token) -> bool {
214        match &self.next {
215            Some((Ok(t), _)) if *t == expected => true,
216            _ => {
217                if self.count < self.attempts.len() {
218                    self.attempts[self.count] = Some(expected);
219                }
220
221                self.count += 1;
222                false
223            }
224        }
225    }
226
227    /// Returns an error based on the attempted tokens.
228    ///
229    /// Panics if no peeks were attempted.
230    pub fn error(self) -> Error {
231        let (found, span) = match self.next {
232            Some((Ok(token), span)) => (Some(token), span),
233            Some((Err(e), s)) => return (e, s).into(),
234            None => (None, self.span),
235        };
236
237        match self.count {
238            0 => unreachable!("lookahead had no attempts"),
239            1 => Error::Expected {
240                expected: self.attempts[0].unwrap(),
241                found,
242                span,
243            },
244            2 => Error::ExpectedEither {
245                first: self.attempts[0].unwrap(),
246                second: self.attempts[1].unwrap(),
247                found,
248                span,
249            },
250            _ => Error::ExpectedMultiple {
251                expected: self.attempts,
252                count: self.count,
253                found,
254                span,
255            },
256        }
257    }
258}
259
/// Trait implemented by AST nodes that can be parsed from a token stream.
pub(crate) trait Parse<'a>: Sized {
    /// Parses the node from the given lexer.
    fn parse(lexer: &mut Lexer<'a>) -> ParseResult<Self>;
}
263
/// Trait implemented by AST nodes that can be detected via a lookahead.
trait Peek {
    /// Returns `true` if the lookahead's next token may start this node.
    fn peek(lookahead: &mut Lookahead) -> bool;
}
267
268fn parse_delimited<'a, T: Parse<'a> + Peek>(
269    lexer: &mut Lexer<'a>,
270    until: Token,
271    with_commas: bool,
272) -> ParseResult<Vec<T>> {
273    let mut items = Vec::new();
274    loop {
275        let mut lookahead = Lookahead::new(lexer);
276        if lookahead.peek(until) {
277            break;
278        }
279
280        if !T::peek(&mut lookahead) {
281            return Err(lookahead.error());
282        }
283
284        items.push(Parse::parse(lexer)?);
285
286        if let Some((Ok(next), _)) = lexer.peek() {
287            if next == until {
288                break;
289            }
290
291            if with_commas {
292                parse_token(lexer, Token::Comma)?;
293            }
294        }
295    }
296
297    Ok(items)
298}
299
300/// Represents a package directive in the AST.
301#[derive(Debug, Clone, Serialize)]
302#[serde(rename_all = "camelCase")]
303pub struct PackageDirective<'a> {
304    /// The name of the package named by the directive.
305    pub package: PackageName<'a>,
306    /// The optional world being targeted by the package.
307    #[serde(skip_serializing_if = "Option::is_none")]
308    pub targets: Option<PackagePath<'a>>,
309}
310
311impl<'a> Parse<'a> for PackageDirective<'a> {
312    fn parse(lexer: &mut Lexer<'a>) -> ParseResult<Self> {
313        parse_token(lexer, Token::PackageKeyword)?;
314        let package = Parse::parse(lexer)?;
315        let targets = parse_optional(lexer, Token::TargetsKeyword, Parse::parse)?;
316        parse_token(lexer, Token::Semicolon)?;
317        Ok(Self { package, targets })
318    }
319}
320
/// Represents a top-level WAC document.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Document<'a> {
    /// The doc comments for the package.
    ///
    /// These are the doc comments collected before the package directive.
    pub docs: Vec<DocComment<'a>>,
    /// The package directive of the document.
    pub directive: PackageDirective<'a>,
    /// The statements in the document.
    pub statements: Vec<Statement<'a>>,
}
332
333impl<'a> Document<'a> {
334    /// Parses the given source string as a document.
335    ///
336    /// The given path is used for error reporting.
337    pub fn parse(source: &'a str) -> ParseResult<Self> {
338        let mut lexer = Lexer::new(source).map_err(Error::from)?;
339
340        let docs = Parse::parse(&mut lexer)?;
341        let directive = Parse::parse(&mut lexer)?;
342
343        let mut statements: Vec<Statement> = Default::default();
344        while lexer.peek().is_some() {
345            statements.push(Parse::parse(&mut lexer)?);
346        }
347
348        assert!(lexer.next().is_none(), "expected all tokens to be consumed");
349        Ok(Self {
350            docs,
351            directive,
352            statements,
353        })
354    }
355
356    /// Resolves the document.
357    ///
358    /// The returned resolution contains an encodable composition graph.
359    pub fn resolve(
360        &self,
361        packages: IndexMap<BorrowedPackageKey<'a>, Vec<u8>>,
362    ) -> ResolutionResult<Resolution> {
363        AstResolver::new(self).resolve(packages)
364    }
365}
366
367/// Represents a statement in the AST.
368#[derive(Debug, Clone, Serialize)]
369pub enum Statement<'a> {
370    /// An import statement.
371    Import(ImportStatement<'a>),
372    /// A type statement.
373    Type(TypeStatement<'a>),
374    /// A let statement.
375    Let(LetStatement<'a>),
376    /// An export statement.
377    Export(ExportStatement<'a>),
378}
379
380impl<'a> Parse<'a> for Statement<'a> {
381    fn parse(lexer: &mut Lexer<'a>) -> ParseResult<Self> {
382        let mut lookahead = Lookahead::new(lexer);
383        if ImportStatement::peek(&mut lookahead) {
384            Ok(Self::Import(Parse::parse(lexer)?))
385        } else if LetStatement::peek(&mut lookahead) {
386            Ok(Self::Let(Parse::parse(lexer)?))
387        } else if ExportStatement::peek(&mut lookahead) {
388            Ok(Self::Export(Parse::parse(lexer)?))
389        } else if TypeStatement::peek(&mut lookahead) {
390            Ok(Self::Type(Parse::parse(lexer)?))
391        } else {
392            Err(lookahead.error())
393        }
394    }
395}
396
397/// Represents an identifier in the AST.
398#[derive(Debug, Clone, Copy, Serialize)]
399#[serde(rename_all = "camelCase")]
400pub struct Ident<'a> {
401    /// The identifier string.
402    pub string: &'a str,
403    /// The span of the identifier.
404    pub span: SourceSpan,
405}
406
407impl<'a> Parse<'a> for Ident<'a> {
408    fn parse(lexer: &mut Lexer<'a>) -> ParseResult<Self> {
409        let span = parse_token(lexer, Token::Ident)?;
410        let id = lexer.source(span);
411        Ok(Self {
412            string: id.strip_prefix('%').unwrap_or(id),
413            span,
414        })
415    }
416}
417
418impl Peek for Ident<'_> {
419    fn peek(lookahead: &mut Lookahead) -> bool {
420        lookahead.peek(Token::Ident)
421    }
422}
423
424/// Represents a string in the AST.
425#[derive(Debug, Copy, Clone, Serialize)]
426#[serde(rename_all = "camelCase")]
427pub struct String<'a> {
428    /// The value of the string (without quotes).
429    pub value: &'a str,
430    /// The span of the string.
431    pub span: SourceSpan,
432}
433
434impl<'a> Parse<'a> for String<'a> {
435    fn parse(lexer: &mut Lexer<'a>) -> ParseResult<Self> {
436        let span = parse_token(lexer, Token::String)?;
437        let s = lexer.source(span);
438        Ok(Self {
439            value: s.strip_prefix('"').unwrap().strip_suffix('"').unwrap(),
440            span,
441        })
442    }
443}
444
445impl Peek for String<'_> {
446    fn peek(lookahead: &mut Lookahead) -> bool {
447        lookahead.peek(Token::String)
448    }
449}
450
451/// Represents a documentation comment in the AST.
452#[derive(Debug, Clone, Serialize)]
453#[serde(rename_all = "camelCase")]
454pub struct DocComment<'a> {
455    /// The comment string.
456    pub comment: &'a str,
457    /// The span of the comment.
458    pub span: SourceSpan,
459}
460
461impl<'a> Parse<'a> for Vec<DocComment<'a>> {
462    fn parse(lexer: &mut Lexer<'a>) -> ParseResult<Vec<DocComment<'a>>> {
463        Ok(lexer
464            .comments()
465            .map_err(Error::from)?
466            .into_iter()
467            .map(|(comment, span)| DocComment { comment, span })
468            .collect())
469    }
470}
471
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A test function for parsing a string into an AST node,
    /// converting the AST node into a string, and then comparing
    /// the result with the given expected string.
    ///
    /// Note that we don't expect the input string to be the same
    /// as the output string, since the input string may contain
    /// extra whitespace and comments that are not preserved in
    /// the AST.
    pub(crate) fn roundtrip(source: &str, expected: &str) -> ParseResult<()> {
        let doc = Document::parse(source)?;
        // `std::string::String` is spelled out because this module defines
        // its own `String` AST node.
        let mut s = std::string::String::new();
        DocumentPrinter::new(&mut s, source, None)
            .document(&doc)
            .unwrap();
        assert_eq!(s, expected, "unexpected AST output");
        Ok(())
    }

    #[test]
    fn document_roundtrip() {
        // The first literal is the input (with extra whitespace and a
        // non-doc comment); the second is the expected normalized output.
        roundtrip(
            r#"/* ignore me */
/// Doc comment for the package!
package test:foo:bar@1.0.0;
/// Doc comment #1!
import foo: foo:bar/baz;
/// Doc comment #2!
type a = u32;
/// Doc comment #3!
record r {
    x: string
}
/// Doc comment #4!
interface i {
    /// Doc comment #5!
    f: func() -> r;
}
/// Doc comment #6!
world w {
    /// Doc comment #7!
    import i;
    /// Doc comment #8!
    export f: func() -> a;
}
/// Doc comment #9!
let x = new foo:bar { };
/// Doc comment #10!
export x as "foo";
"#,
            r#"/// Doc comment for the package!
package test:foo:bar@1.0.0;

/// Doc comment #1!
import foo: foo:bar/baz;

/// Doc comment #2!
type a = u32;

/// Doc comment #3!
record r {
    x: string,
}

/// Doc comment #4!
interface i {
    /// Doc comment #5!
    f: func() -> r;
}

/// Doc comment #6!
world w {
    /// Doc comment #7!
    import i;

    /// Doc comment #8!
    export f: func() -> a;
}

/// Doc comment #9!
let x = new foo:bar {};

/// Doc comment #10!
export x as "foo";
"#,
        )
        .unwrap()
    }
}