1use wit_parser::{
2 Enum, Flags, Function, Record, Resolve, Result_, Tuple, Type, TypeDefKind, TypeId, Variant,
3};
4
5use crate::{value, wasm::WasmValueError};
6
7pub fn resolve_wit_type(resolve: &Resolve, type_id: TypeId) -> Result<value::Type, WasmValueError> {
11 TypeResolver { resolve }.resolve_type_id(type_id)
12}
13
14pub fn resolve_wit_func_type(
18 resolve: &Resolve,
19 function: &Function,
20) -> Result<value::FuncType, WasmValueError> {
21 let resolver = TypeResolver { resolve };
22 let params = resolver.resolve_params(&function.params)?;
23 let results = match &function.result {
24 Some(ty) => vec![("".into(), resolver.resolve_type(*ty)?)],
25 None => Vec::new(),
26 };
27 value::FuncType::new(params, results)
28}
29
/// Converts WIT type definitions held in a [`Resolve`] into [`value::Type`]s.
struct TypeResolver<'a> {
    /// The `Resolve` whose `types` arena every `TypeId` is looked up in.
    resolve: &'a Resolve,
}

/// Outcome of resolving one WIT type into a [`value::Type`].
type ValueResult = Result<value::Type, WasmValueError>;
35
36impl<'a> TypeResolver<'a> {
37 fn resolve_type_id(&self, type_id: TypeId) -> ValueResult {
38 self.resolve(&self.resolve.types.get(type_id).unwrap().kind)
39 }
40
41 fn resolve_type(&self, ty: Type) -> ValueResult {
42 self.resolve(&TypeDefKind::Type(ty))
43 }
44
45 fn resolve_params(
46 &self,
47 params: &[(String, Type)],
48 ) -> Result<Vec<(String, value::Type)>, WasmValueError> {
49 params
50 .iter()
51 .map(|(name, ty)| {
52 let ty = self.resolve_type(*ty)?;
53 Ok((name.clone(), ty))
54 })
55 .collect()
56 }
57
58 fn resolve(&self, mut kind: &'a TypeDefKind) -> ValueResult {
59 while let &TypeDefKind::Type(Type::Id(id)) = kind {
61 kind = &self.resolve.types.get(id).unwrap().kind;
62 }
63
64 match kind {
65 TypeDefKind::Record(record) => self.resolve_record(record),
66 TypeDefKind::Flags(flags) => self.resolve_flags(flags),
67 TypeDefKind::Tuple(tuple) => self.resolve_tuple(tuple),
68 TypeDefKind::Variant(variant) => self.resolve_variant(variant),
69 TypeDefKind::Enum(enum_) => self.resolve_enum(enum_),
70 TypeDefKind::Option(some_type) => self.resolve_option(some_type),
71 TypeDefKind::Result(result) => self.resolve_result(result),
72 TypeDefKind::List(element_type) => self.resolve_list(element_type),
73 TypeDefKind::FixedSizeList(element_type, elements) => {
74 self.resolve_fixed_size_list(element_type, *elements)
75 }
76 TypeDefKind::Type(Type::Bool) => Ok(value::Type::BOOL),
77 TypeDefKind::Type(Type::U8) => Ok(value::Type::U8),
78 TypeDefKind::Type(Type::U16) => Ok(value::Type::U16),
79 TypeDefKind::Type(Type::U32) => Ok(value::Type::U32),
80 TypeDefKind::Type(Type::U64) => Ok(value::Type::U64),
81 TypeDefKind::Type(Type::S8) => Ok(value::Type::S8),
82 TypeDefKind::Type(Type::S16) => Ok(value::Type::S16),
83 TypeDefKind::Type(Type::S32) => Ok(value::Type::S32),
84 TypeDefKind::Type(Type::S64) => Ok(value::Type::S64),
85 TypeDefKind::Type(Type::F32) => Ok(value::Type::F32),
86 TypeDefKind::Type(Type::F64) => Ok(value::Type::F64),
87 TypeDefKind::Type(Type::Char) => Ok(value::Type::CHAR),
88 TypeDefKind::Type(Type::String) => Ok(value::Type::STRING),
89 TypeDefKind::Type(Type::Id(_)) => unreachable!(),
90 other => Err(WasmValueError::UnsupportedType(other.as_str().into())),
91 }
92 }
93
94 fn resolve_record(&self, record: &Record) -> ValueResult {
95 let fields = record
96 .fields
97 .iter()
98 .map(|f| Ok((f.name.as_str(), self.resolve_type(f.ty)?)))
99 .collect::<Result<Vec<_>, _>>()?;
100 Ok(value::Type::record(fields).unwrap())
101 }
102
103 fn resolve_flags(&self, flags: &Flags) -> ValueResult {
104 let names = flags.flags.iter().map(|f| f.name.as_str());
105 Ok(value::Type::flags(names).unwrap())
106 }
107
108 fn resolve_tuple(&self, tuple: &Tuple) -> ValueResult {
109 let types = tuple
110 .types
111 .iter()
112 .map(|ty| self.resolve_type(*ty))
113 .collect::<Result<Vec<_>, _>>()?;
114 Ok(value::Type::tuple(types).unwrap())
115 }
116
117 fn resolve_variant(&self, variant: &Variant) -> ValueResult {
118 let cases = variant
119 .cases
120 .iter()
121 .map(|case| {
122 Ok((
123 case.name.as_str(),
124 case.ty.map(|ty| self.resolve_type(ty)).transpose()?,
125 ))
126 })
127 .collect::<Result<Vec<_>, _>>()?;
128 Ok(value::Type::variant(cases).unwrap())
129 }
130
131 fn resolve_enum(&self, enum_: &Enum) -> ValueResult {
132 let cases = enum_.cases.iter().map(|c| c.name.as_str());
133 Ok(value::Type::enum_ty(cases).unwrap())
134 }
135
136 fn resolve_option(&self, some_type: &Type) -> ValueResult {
137 let some = self.resolve_type(*some_type)?;
138 Ok(value::Type::option(some))
139 }
140
141 fn resolve_result(&self, result: &Result_) -> ValueResult {
142 let ok = result.ok.map(|ty| self.resolve_type(ty)).transpose()?;
143 let err = result.err.map(|ty| self.resolve_type(ty)).transpose()?;
144 Ok(value::Type::result(ok, err))
145 }
146
147 fn resolve_list(&self, element_type: &Type) -> ValueResult {
148 let element_type = self.resolve_type(*element_type)?;
149 Ok(value::Type::list(element_type))
150 }
151
152 fn resolve_fixed_size_list(&self, element_type: &Type, elements: u32) -> ValueResult {
153 let element_type = self.resolve_type(*element_type)?;
154 Ok(value::Type::fixed_size_list(element_type, elements))
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Resolve` containing the single WIT document `source`.
    fn resolve_from_wit(source: &str) -> Resolve {
        let mut resolve = Resolve::new();
        resolve.push_str("test.wit", source).unwrap();
        resolve
    }

    #[test]
    fn resolve_wit_type_smoke_test() {
        let resolve = resolve_from_wit(
            "
package test:types;
interface types {
    type uint8 = u8;
}
    ",
        );

        // The document declares exactly one type, so the first arena entry is `uint8`.
        let type_id = resolve.types.iter().map(|(id, _)| id).next().unwrap();
        assert_eq!(
            resolve_wit_type(&resolve, type_id).unwrap(),
            value::Type::U8
        );
    }

    #[test]
    fn resolve_wit_func_type_smoke_test() {
        let resolve = resolve_from_wit(
            r#"
package test:types;
interface types {
    type uint8 = u8;
    no-results: func(a: uint8, b: string);
    one-result: func(c: uint8, d: string) -> uint8;
    named-results: func(e: uint8, f: string) -> tuple<u8, string>;
}
    "#,
        );

        let cases = [
            ("no-results", "func(a: u8, b: string)"),
            ("one-result", "func(c: u8, d: string) -> u8"),
            (
                "named-results",
                "func(e: u8, f: string) -> tuple<u8, string>",
            ),
        ];
        for (func_name, expected_display) in cases {
            // Search every interface for the function with this name.
            let function = resolve
                .interfaces
                .iter()
                .flat_map(|(_, interface)| interface.functions.iter())
                .find(|(name, _)| name.as_str() == func_name)
                .map(|(_, function)| function)
                .unwrap();
            let func_type = resolve_wit_func_type(&resolve, function).unwrap();
            assert_eq!(func_type.to_string(), expected_display, "for {function:?}");
        }
    }
}