wasm_wave/value/
wit.rs

1use wit_parser::{
2    Enum, Flags, Function, Record, Resolve, Result_, Tuple, Type, TypeDefKind, TypeId, Variant,
3};
4
5use crate::{value, wasm::WasmValueError};
6
7/// Resolves a [`value::Type`] from the given [`wit_parser::Resolve`] and [`TypeId`].
8/// # Panics
9/// Panics if `type_id` is not valid in `resolve`.
10pub fn resolve_wit_type(resolve: &Resolve, type_id: TypeId) -> Result<value::Type, WasmValueError> {
11    TypeResolver { resolve }.resolve_type_id(type_id)
12}
13
14/// Resolves a [`value::FuncType`] from the given [`wit_parser::Resolve`] and [`Function`].
15/// # Panics
16/// Panics if `function`'s types are not valid in `resolve`.
17pub fn resolve_wit_func_type(
18    resolve: &Resolve,
19    function: &Function,
20) -> Result<value::FuncType, WasmValueError> {
21    let resolver = TypeResolver { resolve };
22    let params = resolver.resolve_params(&function.params)?;
23    let results = match &function.result {
24        Some(ty) => vec![("".into(), resolver.resolve_type(*ty)?)],
25        None => Vec::new(),
26    };
27    value::FuncType::new(params, results)
28}
29
30struct TypeResolver<'a> {
31    resolve: &'a Resolve,
32}
33
34type ValueResult = Result<value::Type, WasmValueError>;
35
36impl<'a> TypeResolver<'a> {
37    fn resolve_type_id(&self, type_id: TypeId) -> ValueResult {
38        self.resolve(&self.resolve.types.get(type_id).unwrap().kind)
39    }
40
41    fn resolve_type(&self, ty: Type) -> ValueResult {
42        self.resolve(&TypeDefKind::Type(ty))
43    }
44
45    fn resolve_params(
46        &self,
47        params: &[(String, Type)],
48    ) -> Result<Vec<(String, value::Type)>, WasmValueError> {
49        params
50            .iter()
51            .map(|(name, ty)| {
52                let ty = self.resolve_type(*ty)?;
53                Ok((name.clone(), ty))
54            })
55            .collect()
56    }
57
58    fn resolve(&self, mut kind: &'a TypeDefKind) -> ValueResult {
59        // Recursively resolve any type defs.
60        while let &TypeDefKind::Type(Type::Id(id)) = kind {
61            kind = &self.resolve.types.get(id).unwrap().kind;
62        }
63
64        match kind {
65            TypeDefKind::Record(record) => self.resolve_record(record),
66            TypeDefKind::Flags(flags) => self.resolve_flags(flags),
67            TypeDefKind::Tuple(tuple) => self.resolve_tuple(tuple),
68            TypeDefKind::Variant(variant) => self.resolve_variant(variant),
69            TypeDefKind::Enum(enum_) => self.resolve_enum(enum_),
70            TypeDefKind::Option(some_type) => self.resolve_option(some_type),
71            TypeDefKind::Result(result) => self.resolve_result(result),
72            TypeDefKind::List(element_type) => self.resolve_list(element_type),
73            TypeDefKind::FixedSizeList(element_type, elements) => {
74                self.resolve_fixed_size_list(element_type, *elements)
75            }
76            TypeDefKind::Type(Type::Bool) => Ok(value::Type::BOOL),
77            TypeDefKind::Type(Type::U8) => Ok(value::Type::U8),
78            TypeDefKind::Type(Type::U16) => Ok(value::Type::U16),
79            TypeDefKind::Type(Type::U32) => Ok(value::Type::U32),
80            TypeDefKind::Type(Type::U64) => Ok(value::Type::U64),
81            TypeDefKind::Type(Type::S8) => Ok(value::Type::S8),
82            TypeDefKind::Type(Type::S16) => Ok(value::Type::S16),
83            TypeDefKind::Type(Type::S32) => Ok(value::Type::S32),
84            TypeDefKind::Type(Type::S64) => Ok(value::Type::S64),
85            TypeDefKind::Type(Type::F32) => Ok(value::Type::F32),
86            TypeDefKind::Type(Type::F64) => Ok(value::Type::F64),
87            TypeDefKind::Type(Type::Char) => Ok(value::Type::CHAR),
88            TypeDefKind::Type(Type::String) => Ok(value::Type::STRING),
89            TypeDefKind::Type(Type::Id(_)) => unreachable!(),
90            other => Err(WasmValueError::UnsupportedType(other.as_str().into())),
91        }
92    }
93
94    fn resolve_record(&self, record: &Record) -> ValueResult {
95        let fields = record
96            .fields
97            .iter()
98            .map(|f| Ok((f.name.as_str(), self.resolve_type(f.ty)?)))
99            .collect::<Result<Vec<_>, _>>()?;
100        Ok(value::Type::record(fields).unwrap())
101    }
102
103    fn resolve_flags(&self, flags: &Flags) -> ValueResult {
104        let names = flags.flags.iter().map(|f| f.name.as_str());
105        Ok(value::Type::flags(names).unwrap())
106    }
107
108    fn resolve_tuple(&self, tuple: &Tuple) -> ValueResult {
109        let types = tuple
110            .types
111            .iter()
112            .map(|ty| self.resolve_type(*ty))
113            .collect::<Result<Vec<_>, _>>()?;
114        Ok(value::Type::tuple(types).unwrap())
115    }
116
117    fn resolve_variant(&self, variant: &Variant) -> ValueResult {
118        let cases = variant
119            .cases
120            .iter()
121            .map(|case| {
122                Ok((
123                    case.name.as_str(),
124                    case.ty.map(|ty| self.resolve_type(ty)).transpose()?,
125                ))
126            })
127            .collect::<Result<Vec<_>, _>>()?;
128        Ok(value::Type::variant(cases).unwrap())
129    }
130
131    fn resolve_enum(&self, enum_: &Enum) -> ValueResult {
132        let cases = enum_.cases.iter().map(|c| c.name.as_str());
133        Ok(value::Type::enum_ty(cases).unwrap())
134    }
135
136    fn resolve_option(&self, some_type: &Type) -> ValueResult {
137        let some = self.resolve_type(*some_type)?;
138        Ok(value::Type::option(some))
139    }
140
141    fn resolve_result(&self, result: &Result_) -> ValueResult {
142        let ok = result.ok.map(|ty| self.resolve_type(ty)).transpose()?;
143        let err = result.err.map(|ty| self.resolve_type(ty)).transpose()?;
144        Ok(value::Type::result(ok, err))
145    }
146
147    fn resolve_list(&self, element_type: &Type) -> ValueResult {
148        let element_type = self.resolve_type(*element_type)?;
149        Ok(value::Type::list(element_type))
150    }
151
152    fn resolve_fixed_size_list(&self, element_type: &Type, elements: u32) -> ValueResult {
153        let element_type = self.resolve_type(*element_type)?;
154        Ok(value::Type::fixed_size_list(element_type, elements))
155    }
156}
157
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the ID of the (unique) type named `name` in `resolve`.
    fn type_named(resolve: &Resolve, name: &str) -> TypeId {
        resolve
            .types
            .iter()
            .find_map(|(id, def)| (def.name.as_deref() == Some(name)).then_some(id))
            .unwrap_or_else(|| panic!("no type named {name:?}"))
    }

    #[test]
    fn resolve_wit_type_smoke_test() {
        let mut resolve = Resolve::new();
        resolve
            .push_str(
                "test.wit",
                "
package test:types;
interface types {
    type uint8 = u8;
}
                ",
            )
            .unwrap();

        let (type_id, _) = resolve.types.iter().next().unwrap();
        let ty = resolve_wit_type(&resolve, type_id).unwrap();
        assert_eq!(ty, value::Type::U8);
    }

    #[test]
    fn resolve_wit_type_composites() {
        let mut resolve = Resolve::new();
        resolve
            .push_str(
                "test.wit",
                "
package test:types;
interface types {
    type byte = u8;
    type alias-byte = byte;
    record point { x: u8, y: string }
    flags perms { readable, writable }
    enum color { red, green }
    variant shape { circle(u8), dot }
    type maybe = option<u8>;
    type outcome = result<u8>;
    type names = list<string>;
    type pair = tuple<u8, string>;
}
                ",
            )
            .unwrap();

        let resolved =
            |name: &str| resolve_wit_type(&resolve, type_named(&resolve, name)).unwrap();

        // Alias chains are followed to the underlying primitive.
        assert_eq!(resolved("alias-byte"), value::Type::U8);
        assert_eq!(
            resolved("point"),
            value::Type::record(vec![("x", value::Type::U8), ("y", value::Type::STRING)])
                .unwrap()
        );
        assert_eq!(
            resolved("perms"),
            value::Type::flags(["readable", "writable"].into_iter()).unwrap()
        );
        assert_eq!(
            resolved("color"),
            value::Type::enum_ty(["red", "green"].into_iter()).unwrap()
        );
        assert_eq!(
            resolved("shape"),
            value::Type::variant(vec![("circle", Some(value::Type::U8)), ("dot", None)]).unwrap()
        );
        assert_eq!(resolved("maybe"), value::Type::option(value::Type::U8));
        assert_eq!(
            resolved("outcome"),
            value::Type::result(Some(value::Type::U8), None)
        );
        assert_eq!(resolved("names"), value::Type::list(value::Type::STRING));
        assert_eq!(
            resolved("pair"),
            value::Type::tuple(vec![value::Type::U8, value::Type::STRING]).unwrap()
        );
    }

    #[test]
    fn resolve_wit_type_rejects_resources() {
        let mut resolve = Resolve::new();
        resolve
            .push_str(
                "test.wit",
                "
package test:types;
interface types {
    resource blob;
}
                ",
            )
            .unwrap();

        let result = resolve_wit_type(&resolve, type_named(&resolve, "blob"));
        assert!(
            matches!(result, Err(WasmValueError::UnsupportedType(_))),
            "unexpected result: {result:?}"
        );
    }

    #[test]
    fn resolve_wit_func_type_smoke_test() {
        let mut resolve = Resolve::new();
        resolve
            .push_str(
                "test.wit",
                r#"
package test:types;
interface types {
    type uint8 = u8;
    no-results: func(a: uint8, b: string);
    one-result: func(c: uint8, d: string) -> uint8;
    tuple-result: func(e: uint8, f: string) -> tuple<u8, string>;
}
                "#,
            )
            .unwrap();

        for (func_name, expected_display) in [
            ("no-results", "func(a: u8, b: string)"),
            ("one-result", "func(c: u8, d: string) -> u8"),
            (
                "tuple-result",
                "func(e: u8, f: string) -> tuple<u8, string>",
            ),
        ] {
            let function = resolve
                .interfaces
                .iter()
                .flat_map(|(_, i)| &i.functions)
                .find_map(|(name, function)| (name == func_name).then_some(function))
                .unwrap();
            let ty = resolve_wit_func_type(&resolve, function).unwrap();
            assert_eq!(ty.to_string(), expected_display, "for {function:?}");
        }
    }
}