wasm_wave/
ast.rs

1//! Abstract syntax tree types
2
3use std::{borrow::Cow, collections::HashMap, str::FromStr};
4
5use crate::{
6    lex::Span,
7    parser::{ParserError, ParserErrorKind},
8    strings::{StringPartsIter, unescape},
9    wasm::{WasmType, WasmTypeKind, WasmValue, WasmValueError},
10};
11
12/// A WAVE AST node.
13#[derive(Clone, Debug)]
14pub struct Node {
15    ty: NodeType,
16    span: Span,
17    children: Vec<Node>,
18}
19
20impl Node {
21    pub(crate) fn new(
22        ty: NodeType,
23        span: impl Into<Span>,
24        children: impl IntoIterator<Item = Node>,
25    ) -> Self {
26        Self {
27            ty,
28            span: span.into(),
29            children: Vec::from_iter(children),
30        }
31    }
32
33    /// Returns this node's type.
34    pub fn ty(&self) -> NodeType {
35        self.ty
36    }
37
38    /// Returns this node's span.
39    pub fn span(&self) -> Span {
40        self.span.clone()
41    }
42
43    /// Returns a bool value if this node represents a bool.
44    pub fn as_bool(&self) -> Result<bool, ParserError> {
45        match self.ty {
46            NodeType::BoolTrue => Ok(true),
47            NodeType::BoolFalse => Ok(false),
48            _ => Err(self.error(ParserErrorKind::InvalidType)),
49        }
50    }
51
52    /// Returns a number value of the given type (integer or float) if this node
53    /// can represent a number of that type.
54    pub fn as_number<T: FromStr>(&self, src: &str) -> Result<T, ParserError> {
55        self.ensure_type(NodeType::Number)?;
56        self.slice(src)
57            .parse()
58            .map_err(|_| self.error(ParserErrorKind::InvalidValue))
59    }
60
61    /// Returns a char value if this node represents a valid char.
62    pub fn as_char(&self, src: &str) -> Result<char, ParserError> {
63        self.ensure_type(NodeType::Char)?;
64        let inner = &src[self.span.start + 1..self.span.end - 1];
65        let (ch, len) = if inner.starts_with('\\') {
66            unescape(inner).ok_or_else(|| self.error(ParserErrorKind::InvalidEscape))?
67        } else {
68            let ch = inner.chars().next().unwrap();
69            (ch, ch.len_utf8())
70        };
71        // Verify length
72        if len != inner.len() {
73            return Err(self.error(ParserErrorKind::MultipleChars));
74        }
75        Ok(ch)
76    }
77
78    /// Returns a str value if this node represents a valid string.
79    pub fn as_str<'src>(&self, src: &'src str) -> Result<Cow<'src, str>, ParserError> {
80        let mut parts = self.iter_str(src)?;
81        let Some(first) = parts.next().transpose()? else {
82            return Ok("".into());
83        };
84        match parts.next().transpose()? {
85            // Single part may be borrowed
86            None => Ok(first),
87            // Multiple parts must be collected into a single owned String
88            Some(second) => {
89                let s: String = [Ok(first), Ok(second)]
90                    .into_iter()
91                    .chain(parts)
92                    .collect::<Result<_, _>>()?;
93                Ok(s.into())
94            }
95        }
96    }
97
98    /// Returns an iterator of string "parts" which together form a decoded
99    /// string value if this node represents a valid string.
100    pub fn iter_str<'src>(
101        &self,
102        src: &'src str,
103    ) -> Result<impl Iterator<Item = Result<Cow<'src, str>, ParserError>>, ParserError> {
104        match self.ty {
105            NodeType::String => {
106                let span = self.span.start + 1..self.span.end - 1;
107                Ok(StringPartsIter::new(&src[span.clone()], span.start))
108            }
109            NodeType::MultilineString => {
110                let span = self.span.start + 3..self.span.end - 3;
111                Ok(StringPartsIter::new_multiline(
112                    &src[span.clone()],
113                    span.start,
114                )?)
115            }
116            _ => Err(self.error(ParserErrorKind::InvalidType)),
117        }
118    }
119
120    /// Returns an iterator of value nodes if this node represents a tuple.
121    pub fn as_tuple(&self) -> Result<impl ExactSizeIterator<Item = &Node>, ParserError> {
122        self.ensure_type(NodeType::Tuple)?;
123        Ok(self.children.iter())
124    }
125
126    /// Returns an iterator of value nodes if this node represents a list.
127    pub fn as_list(&self) -> Result<impl ExactSizeIterator<Item = &Node>, ParserError> {
128        self.ensure_type(NodeType::List)?;
129        Ok(self.children.iter())
130    }
131
132    /// Returns an iterator of field name and value node pairs if this node
133    /// represents a record value.
134    pub fn as_record<'this, 'src>(
135        &'this self,
136        src: &'src str,
137    ) -> Result<impl ExactSizeIterator<Item = (&'src str, &'this Node)>, ParserError> {
138        self.ensure_type(NodeType::Record)?;
139        Ok(self
140            .children
141            .chunks(2)
142            .map(|chunk| (chunk[0].as_label(src).unwrap(), &chunk[1])))
143    }
144
145    /// Returns a variant label and optional payload if this node can represent
146    /// a variant value.
147    pub fn as_variant<'this, 'src>(
148        &'this self,
149        src: &'src str,
150    ) -> Result<(&'src str, Option<&'this Node>), ParserError> {
151        match self.ty {
152            NodeType::Label => Ok((self.as_label(src)?, None)),
153            NodeType::VariantWithPayload => {
154                let label = self.children[0].as_label(src)?;
155                let value = &self.children[1];
156                Ok((label, Some(value)))
157            }
158            _ => Err(self.error(ParserErrorKind::InvalidType)),
159        }
160    }
161
162    /// Returns an enum value label if this node represents a label.
163    pub fn as_enum<'src>(&self, src: &'src str) -> Result<&'src str, ParserError> {
164        self.as_label(src)
165    }
166
167    /// Returns an option value if this node represents an option.
168    pub fn as_option(&self) -> Result<Option<&Node>, ParserError> {
169        match self.ty {
170            NodeType::OptionSome => Ok(Some(&self.children[0])),
171            NodeType::OptionNone => Ok(None),
172            _ => Err(self.error(ParserErrorKind::InvalidType)),
173        }
174    }
175
176    /// Returns a result value with optional payload value if this node
177    /// represents a result.
178    pub fn as_result(&self) -> Result<Result<Option<&Node>, Option<&Node>>, ParserError> {
179        let payload = self.children.first();
180        match self.ty {
181            NodeType::ResultOk => Ok(Ok(payload)),
182            NodeType::ResultErr => Ok(Err(payload)),
183            _ => Err(self.error(ParserErrorKind::InvalidType)),
184        }
185    }
186
187    /// Returns an iterator of flag labels if this node represents flags.
188    pub fn as_flags<'this, 'src: 'this>(
189        &'this self,
190        src: &'src str,
191    ) -> Result<impl Iterator<Item = &'src str> + 'this, ParserError> {
192        self.ensure_type(NodeType::Flags)?;
193        Ok(self.children.iter().map(|node| {
194            debug_assert_eq!(node.ty, NodeType::Label);
195            node.slice(src)
196        }))
197    }
198
199    fn as_label<'src>(&self, src: &'src str) -> Result<&'src str, ParserError> {
200        self.ensure_type(NodeType::Label)?;
201        let label = self.slice(src);
202        let label = label.strip_prefix('%').unwrap_or(label);
203        Ok(label)
204    }
205
206    /// Converts this node into the given typed value from the given input source.
207    pub fn to_wasm_value<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
208        Ok(match ty.kind() {
209            WasmTypeKind::Bool => V::make_bool(self.as_bool()?),
210            WasmTypeKind::S8 => V::make_s8(self.as_number(src)?),
211            WasmTypeKind::S16 => V::make_s16(self.as_number(src)?),
212            WasmTypeKind::S32 => V::make_s32(self.as_number(src)?),
213            WasmTypeKind::S64 => V::make_s64(self.as_number(src)?),
214            WasmTypeKind::U8 => V::make_u8(self.as_number(src)?),
215            WasmTypeKind::U16 => V::make_u16(self.as_number(src)?),
216            WasmTypeKind::U32 => V::make_u32(self.as_number(src)?),
217            WasmTypeKind::U64 => V::make_u64(self.as_number(src)?),
218            WasmTypeKind::F32 => V::make_f32(self.as_number(src)?),
219            WasmTypeKind::F64 => V::make_f64(self.as_number(src)?),
220            WasmTypeKind::Char => V::make_char(self.as_char(src)?),
221            WasmTypeKind::String => V::make_string(self.as_str(src)?),
222            WasmTypeKind::List => self.to_wasm_list(ty, src)?,
223            WasmTypeKind::Record => self.to_wasm_record(ty, src)?,
224            WasmTypeKind::Tuple => self.to_wasm_tuple(ty, src)?,
225            WasmTypeKind::Variant => self.to_wasm_variant(ty, src)?,
226            WasmTypeKind::Enum => self.to_wasm_enum(ty, src)?,
227            WasmTypeKind::Option => self.to_wasm_option(ty, src)?,
228            WasmTypeKind::Result => self.to_wasm_result(ty, src)?,
229            WasmTypeKind::Flags => self.to_wasm_flags(ty, src)?,
230            other => {
231                return Err(
232                    self.wasm_value_error(WasmValueError::UnsupportedType(other.to_string()))
233                );
234            }
235        })
236    }
237
238    /// Converts this node into the given types.
239    /// See [`crate::untyped::UntypedFuncCall::to_wasm_params`].
240    pub fn to_wasm_params<'types, V: WasmValue + 'static>(
241        &self,
242        types: impl IntoIterator<Item = &'types V::Type>,
243        src: &str,
244    ) -> Result<Vec<V>, ParserError> {
245        let mut types = types.into_iter();
246        let mut values = self
247            .as_tuple()?
248            .map(|node| {
249                let ty = types.next().ok_or_else(|| {
250                    ParserError::with_detail(
251                        ParserErrorKind::InvalidParams,
252                        node.span().clone(),
253                        "more param(s) than expected",
254                    )
255                })?;
256                node.to_wasm_value::<V>(ty, src)
257            })
258            .collect::<Result<Vec<_>, _>>()?;
259        // Handle trailing optional fields
260        for ty in types {
261            if ty.kind() == WasmTypeKind::Option {
262                values.push(V::make_option(ty, None).map_err(|err| self.wasm_value_error(err))?);
263            } else {
264                return Err(ParserError::with_detail(
265                    ParserErrorKind::InvalidParams,
266                    self.span.end - 1..self.span.end,
267                    "missing required param(s)",
268                ));
269            }
270        }
271        Ok(values)
272    }
273
274    fn to_wasm_list<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
275        let element_type = ty.list_element_type().unwrap();
276        let elements = self
277            .as_list()?
278            .map(|node| node.to_wasm_value(&element_type, src))
279            .collect::<Result<Vec<_>, _>>()?;
280        V::make_list(ty, elements).map_err(|err| self.wasm_value_error(err))
281    }
282
283    fn to_wasm_record<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
284        let values = self.as_record(src)?.collect::<HashMap<_, _>>();
285        let record_fields = ty.record_fields().collect::<Vec<_>>();
286        let fields = record_fields
287            .iter()
288            .map(|(name, field_type)| {
289                let value = match (values.get(name.as_ref()), field_type.kind()) {
290                    (Some(node), _) => node.to_wasm_value(field_type, src)?,
291                    (None, WasmTypeKind::Option) => V::make_option(field_type, None)
292                        .map_err(|err| self.wasm_value_error(err))?,
293                    _ => {
294                        return Err(
295                            self.wasm_value_error(WasmValueError::MissingField(name.to_string()))
296                        );
297                    }
298                };
299                Ok((name.as_ref(), value))
300            })
301            .collect::<Result<Vec<_>, _>>()?;
302        V::make_record(ty, fields).map_err(|err| self.wasm_value_error(err))
303    }
304
305    fn to_wasm_tuple<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
306        let types = ty.tuple_element_types().collect::<Vec<_>>();
307        let values = self.as_tuple()?;
308        if types.len() != values.len() {
309            return Err(
310                self.wasm_value_error(WasmValueError::WrongNumberOfTupleValues {
311                    want: types.len(),
312                    got: values.len(),
313                }),
314            );
315        }
316        let values = ty
317            .tuple_element_types()
318            .zip(self.as_tuple()?)
319            .map(|(ty, node)| node.to_wasm_value(&ty, src))
320            .collect::<Result<Vec<_>, _>>()?;
321        V::make_tuple(ty, values).map_err(|err| self.wasm_value_error(err))
322    }
323
324    fn to_wasm_variant<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
325        let (label, payload) = self.as_variant(src)?;
326        let payload_type = ty
327            .variant_cases()
328            .find_map(|(case, payload)| (case == label).then_some(payload))
329            .ok_or_else(|| self.wasm_value_error(WasmValueError::UnknownCase(label.into())))?;
330        let payload_value = self.to_wasm_maybe_payload(label, &payload_type, payload, src)?;
331        V::make_variant(ty, label, payload_value).map_err(|err| self.wasm_value_error(err))
332    }
333
334    fn to_wasm_enum<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
335        V::make_enum(ty, self.as_enum(src)?).map_err(|err| self.wasm_value_error(err))
336    }
337
338    fn to_wasm_option<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
339        let payload_type = ty.option_some_type().unwrap();
340        let value = match self.ty {
341            NodeType::OptionSome => {
342                self.to_wasm_maybe_payload("some", &Some(payload_type), self.as_option()?, src)?
343            }
344            NodeType::OptionNone => {
345                self.to_wasm_maybe_payload("none", &None, self.as_option()?, src)?
346            }
347            _ if flattenable(payload_type.kind()) => Some(self.to_wasm_value(&payload_type, src)?),
348            _ => {
349                return Err(self.error(ParserErrorKind::InvalidType));
350            }
351        };
352        V::make_option(ty, value).map_err(|err| self.wasm_value_error(err))
353    }
354
355    fn to_wasm_result<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
356        let (ok_type, err_type) = ty.result_types().unwrap();
357        let value = match self.ty {
358            NodeType::ResultOk => {
359                Ok(self.to_wasm_maybe_payload("ok", &ok_type, self.as_result()?.unwrap(), src)?)
360            }
361            NodeType::ResultErr => Err(self.to_wasm_maybe_payload(
362                "err",
363                &err_type,
364                self.as_result()?.unwrap_err(),
365                src,
366            )?),
367            _ => match ok_type {
368                Some(ty) if flattenable(ty.kind()) => Ok(Some(self.to_wasm_value(&ty, src)?)),
369                _ => return Err(self.error(ParserErrorKind::InvalidType)),
370            },
371        };
372        V::make_result(ty, value).map_err(|err| self.wasm_value_error(err))
373    }
374
375    fn to_wasm_flags<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
376        V::make_flags(ty, self.as_flags(src)?).map_err(|err| self.wasm_value_error(err))
377    }
378
379    fn to_wasm_maybe_payload<V: WasmValue>(
380        &self,
381        case: &str,
382        payload_type: &Option<V::Type>,
383        payload: Option<&Node>,
384        src: &str,
385    ) -> Result<Option<V>, ParserError> {
386        match (payload_type.as_ref(), payload) {
387            (Some(ty), Some(node)) => Ok(Some(node.to_wasm_value(ty, src)?)),
388            (None, None) => Ok(None),
389            (Some(_), None) => {
390                Err(self.wasm_value_error(WasmValueError::MissingPayload(case.into())))
391            }
392            (None, Some(_)) => {
393                Err(self.wasm_value_error(WasmValueError::UnexpectedPayload(case.into())))
394            }
395        }
396    }
397
398    fn wasm_value_error(&self, err: WasmValueError) -> ParserError {
399        ParserError::with_source(ParserErrorKind::WasmValueError, self.span(), err)
400    }
401
402    pub(crate) fn slice<'src>(&self, src: &'src str) -> &'src str {
403        &src[self.span()]
404    }
405
406    fn ensure_type(&self, ty: NodeType) -> Result<(), ParserError> {
407        if self.ty == ty {
408            Ok(())
409        } else {
410            Err(self.error(ParserErrorKind::InvalidType))
411        }
412    }
413
414    fn error(&self, kind: ParserErrorKind) -> ParserError {
415        ParserError::new(kind, self.span())
416    }
417}
418
419fn flattenable(kind: WasmTypeKind) -> bool {
420    // TODO: Consider wither to allow flattening an option in a result or vice-versa.
421    !matches!(kind, WasmTypeKind::Option | WasmTypeKind::Result)
422}
423
/// The type of a WAVE AST [`Node`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Boolean `true`.
    BoolTrue,
    /// Boolean `false`.
    BoolFalse,
    /// Number.
    ///
    /// May be an integer or float, including `nan`, `inf`, `-inf`.
    Number,
    /// Char.
    ///
    /// Span includes delimiters.
    Char,
    /// String.
    ///
    /// Span includes delimiters.
    String,
    /// Multiline string.
    ///
    /// Span includes delimiters.
    MultilineString,
    /// Tuple.
    ///
    /// Child nodes are the tuple values.
    Tuple,
    /// List.
    ///
    /// Child nodes are the list values.
    List,
    /// Record.
    ///
    /// Child nodes are field Label, value pairs, e.g.
    /// `[<field 1>, <value 1>, <field 2>, <value 2>, ...]`
    Record,
    /// Label.
    ///
    /// In value position may represent an enum value or variant case (without payload).
    Label,
    /// Variant case with payload.
    ///
    /// Child nodes are variant case Label and payload value.
    VariantWithPayload,
    /// Option `some`.
    ///
    /// Child node is the payload value.
    OptionSome,
    /// Option `none`.
    OptionNone,
    /// Result `ok`.
    ///
    /// Has zero or one child node for the payload value.
    ResultOk,
    /// Result `err`.
    ///
    /// Has zero or one child node for the payload value.
    ResultErr,
    /// Flags.
    ///
    /// Child nodes are flag Labels.
    Flags,
}