Skip to main content

wit_parser/
live.rs

1use crate::{
2    Function, FunctionKind, IndexSet, InterfaceId, Resolve, Type, TypeDef, TypeDefKind, TypeId,
3    WorldId, WorldItem,
4};
5
6#[derive(Default)]
7pub struct LiveTypes {
8    set: IndexSet<TypeId>,
9}
10
11impl LiveTypes {
12    pub fn iter(&self) -> impl Iterator<Item = TypeId> + '_ {
13        self.set.iter().copied()
14    }
15
16    pub fn len(&self) -> usize {
17        self.set.len()
18    }
19
20    pub fn contains(&self, id: TypeId) -> bool {
21        self.set.contains(&id)
22    }
23
24    pub fn add_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
25        self.visit_interface(resolve, iface);
26    }
27
28    pub fn add_world(&mut self, resolve: &Resolve, world: WorldId) {
29        self.visit_world(resolve, world);
30    }
31
32    pub fn add_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
33        self.visit_world_item(resolve, item);
34    }
35
36    pub fn add_func(&mut self, resolve: &Resolve, func: &Function) {
37        self.visit_func(resolve, func);
38    }
39
40    pub fn add_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
41        self.visit_type_id(resolve, ty);
42    }
43
44    pub fn add_type(&mut self, resolve: &Resolve, ty: &Type) {
45        self.visit_type(resolve, ty);
46    }
47}
48
49impl TypeIdVisitor for LiveTypes {
50    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
51        !self.set.contains(&id)
52    }
53
54    fn after_visit_type_id(&mut self, id: TypeId) {
55        assert!(self.set.insert(id));
56    }
57}
58
59/// Helper trait to walk the structure of a type and visit all `TypeId`s that
60/// it refers to, possibly transitively.
61pub trait TypeIdVisitor {
62    /// Callback invoked just before a type is visited.
63    ///
64    /// If this function returns `false` the type is not visited, otherwise it's
65    /// recursed into.
66    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
67        let _ = id;
68        true
69    }
70
71    /// Callback invoked once a type is finished being visited.
72    fn after_visit_type_id(&mut self, id: TypeId) {
73        let _ = id;
74    }
75
76    fn visit_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
77        let iface = &resolve.interfaces[iface];
78        for (_, id) in iface.types.iter() {
79            self.visit_type_id(resolve, *id);
80        }
81        for (_, func) in iface.functions.iter() {
82            self.visit_func(resolve, func);
83        }
84    }
85
86    fn visit_world(&mut self, resolve: &Resolve, world: WorldId) {
87        let world = &resolve.worlds[world];
88        for (_, item) in world.imports.iter().chain(world.exports.iter()) {
89            self.visit_world_item(resolve, item);
90        }
91    }
92
93    fn visit_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
94        match item {
95            WorldItem::Interface { id, .. } => self.visit_interface(resolve, *id),
96            WorldItem::Function(f) => self.visit_func(resolve, f),
97            WorldItem::Type { id, .. } => self.visit_type_id(resolve, *id),
98        }
99    }
100
101    fn visit_func(&mut self, resolve: &Resolve, func: &Function) {
102        match func.kind {
103            // This resource is live as it's attached to a static method but
104            // it's not guaranteed to be present in either params or results, so
105            // be sure to attach it here.
106            FunctionKind::Static(id) | FunctionKind::AsyncStatic(id) => {
107                self.visit_type_id(resolve, id)
108            }
109
110            // The resource these are attached to is in the params/results, so
111            // no need to re-add it here.
112            FunctionKind::Method(_)
113            | FunctionKind::AsyncMethod(_)
114            | FunctionKind::Constructor(_) => {}
115
116            FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {}
117        }
118
119        for param in func.params.iter() {
120            self.visit_type(resolve, &param.ty);
121        }
122        if let Some(ty) = &func.result {
123            self.visit_type(resolve, ty);
124        }
125    }
126
127    fn visit_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
128        if self.before_visit_type_id(ty) {
129            self.visit_type_def(resolve, &resolve.types[ty]);
130            self.after_visit_type_id(ty);
131        }
132    }
133
134    fn visit_type_def(&mut self, resolve: &Resolve, ty: &TypeDef) {
135        match &ty.kind {
136            TypeDefKind::Type(t)
137            | TypeDefKind::List(t)
138            | TypeDefKind::FixedLengthList(t, ..)
139            | TypeDefKind::Option(t)
140            | TypeDefKind::Future(Some(t))
141            | TypeDefKind::Stream(Some(t)) => self.visit_type(resolve, t),
142            TypeDefKind::Map(k, v) => {
143                self.visit_type(resolve, k);
144                self.visit_type(resolve, v);
145            }
146            TypeDefKind::Handle(handle) => match handle {
147                crate::Handle::Own(ty) => self.visit_type_id(resolve, *ty),
148                crate::Handle::Borrow(ty) => self.visit_type_id(resolve, *ty),
149            },
150            TypeDefKind::Resource => {}
151            TypeDefKind::Record(r) => {
152                for field in r.fields.iter() {
153                    self.visit_type(resolve, &field.ty);
154                }
155            }
156            TypeDefKind::Tuple(r) => {
157                for ty in r.types.iter() {
158                    self.visit_type(resolve, ty);
159                }
160            }
161            TypeDefKind::Variant(v) => {
162                for case in v.cases.iter() {
163                    if let Some(ty) = &case.ty {
164                        self.visit_type(resolve, ty);
165                    }
166                }
167            }
168            TypeDefKind::Result(r) => {
169                if let Some(ty) = &r.ok {
170                    self.visit_type(resolve, ty);
171                }
172                if let Some(ty) = &r.err {
173                    self.visit_type(resolve, ty);
174                }
175            }
176            TypeDefKind::Flags(_)
177            | TypeDefKind::Enum(_)
178            | TypeDefKind::Future(None)
179            | TypeDefKind::Stream(None) => {}
180            TypeDefKind::Unknown => unreachable!(),
181        }
182    }
183
184    fn visit_type(&mut self, resolve: &Resolve, ty: &Type) {
185        match ty {
186            Type::Id(id) => self.visit_type_id(resolve, *id),
187            _ => {}
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::{LiveTypes, Resolve};
    use alloc::string::String;
    use alloc::vec::Vec;

    /// Parses `wit` and starts a liveness walk at the type named `ty`.
    ///
    /// The type is looked up in the last interface of the resulting `Resolve`.
    /// Returns the names of the named live types, in the order `LiveTypes`
    /// recorded them.
    fn live(wit: &str, ty: &str) -> Vec<String> {
        let mut resolve = Resolve::default();
        resolve.push_str("test.wit", wit).unwrap();

        let root = {
            let (_, last_iface) = resolve.interfaces.iter().next_back().unwrap();
            last_iface.types[ty]
        };

        let mut found = LiveTypes::default();
        found.add_type_id(&resolve, root);

        // Anonymous types (e.g. an inline `tuple<...>`) have no name and are
        // skipped.
        let mut names = Vec::new();
        for id in found.iter() {
            if let Some(name) = &resolve.types[id].name {
                names.push(name.clone());
            }
        }
        names
    }

    #[test]
    fn no_deps() {
        let wit = "
                package foo:bar;

                interface foo {
                    type t = u32;
                }
            ";
        assert_eq!(live(wit, "t"), ["t"]);
    }

    #[test]
    fn one_dep() {
        let wit = "
                package foo:bar;

                interface foo {
                    type t = u32;
                    type u = t;
                }
            ";
        assert_eq!(live(wit, "u"), ["t", "u"]);
    }

    #[test]
    fn chain() {
        let wit = "
                package foo:bar;

                interface foo {
                    resource t1;
                    record t2 {
                        x: t1,
                    }
                    variant t3 {
                        x(t2),
                    }
                    flags t4 { a }
                    enum t5 { a }
                    type t6 = tuple<t5, t4, t3>;
                }
            ";
        assert_eq!(live(wit, "t6"), ["t5", "t4", "t1", "t2", "t3", "t6"]);
    }
}
265}