wit_parser/
live.rs

1use crate::{
2    Function, FunctionKind, InterfaceId, Resolve, Type, TypeDef, TypeDefKind, TypeId, WorldId,
3    WorldItem,
4};
5use indexmap::IndexSet;
6
7#[derive(Default)]
8pub struct LiveTypes {
9    set: IndexSet<TypeId>,
10}
11
12impl LiveTypes {
13    pub fn iter(&self) -> impl Iterator<Item = TypeId> + '_ {
14        self.set.iter().copied()
15    }
16
17    pub fn len(&self) -> usize {
18        self.set.len()
19    }
20
21    pub fn contains(&self, id: TypeId) -> bool {
22        self.set.contains(&id)
23    }
24
25    pub fn add_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
26        self.visit_interface(resolve, iface);
27    }
28
29    pub fn add_world(&mut self, resolve: &Resolve, world: WorldId) {
30        self.visit_world(resolve, world);
31    }
32
33    pub fn add_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
34        self.visit_world_item(resolve, item);
35    }
36
37    pub fn add_func(&mut self, resolve: &Resolve, func: &Function) {
38        self.visit_func(resolve, func);
39    }
40
41    pub fn add_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
42        self.visit_type_id(resolve, ty);
43    }
44
45    pub fn add_type(&mut self, resolve: &Resolve, ty: &Type) {
46        self.visit_type(resolve, ty);
47    }
48}
49
50impl TypeIdVisitor for LiveTypes {
51    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
52        !self.set.contains(&id)
53    }
54
55    fn after_visit_type_id(&mut self, id: TypeId) {
56        assert!(self.set.insert(id));
57    }
58}
59
60/// Helper trait to walk the structure of a type and visit all `TypeId`s that
61/// it refers to, possibly transitively.
62pub trait TypeIdVisitor {
63    /// Callback invoked just before a type is visited.
64    ///
65    /// If this function returns `false` the type is not visited, otherwise it's
66    /// recursed into.
67    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
68        let _ = id;
69        true
70    }
71
72    /// Callback invoked once a type is finished being visited.
73    fn after_visit_type_id(&mut self, id: TypeId) {
74        let _ = id;
75    }
76
77    fn visit_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
78        let iface = &resolve.interfaces[iface];
79        for (_, id) in iface.types.iter() {
80            self.visit_type_id(resolve, *id);
81        }
82        for (_, func) in iface.functions.iter() {
83            self.visit_func(resolve, func);
84        }
85    }
86
87    fn visit_world(&mut self, resolve: &Resolve, world: WorldId) {
88        let world = &resolve.worlds[world];
89        for (_, item) in world.imports.iter().chain(world.exports.iter()) {
90            self.visit_world_item(resolve, item);
91        }
92    }
93
94    fn visit_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
95        match item {
96            WorldItem::Interface { id, .. } => self.visit_interface(resolve, *id),
97            WorldItem::Function(f) => self.visit_func(resolve, f),
98            WorldItem::Type(t) => self.visit_type_id(resolve, *t),
99        }
100    }
101
102    fn visit_func(&mut self, resolve: &Resolve, func: &Function) {
103        match func.kind {
104            // This resource is live as it's attached to a static method but
105            // it's not guaranteed to be present in either params or results, so
106            // be sure to attach it here.
107            FunctionKind::Static(id) | FunctionKind::AsyncStatic(id) => {
108                self.visit_type_id(resolve, id)
109            }
110
111            // The resource these are attached to is in the params/results, so
112            // no need to re-add it here.
113            FunctionKind::Method(_)
114            | FunctionKind::AsyncMethod(_)
115            | FunctionKind::Constructor(_) => {}
116
117            FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {}
118        }
119
120        for (_, ty) in func.params.iter() {
121            self.visit_type(resolve, ty);
122        }
123        if let Some(ty) = &func.result {
124            self.visit_type(resolve, ty);
125        }
126    }
127
128    fn visit_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
129        if self.before_visit_type_id(ty) {
130            self.visit_type_def(resolve, &resolve.types[ty]);
131            self.after_visit_type_id(ty);
132        }
133    }
134
135    fn visit_type_def(&mut self, resolve: &Resolve, ty: &TypeDef) {
136        match &ty.kind {
137            TypeDefKind::Type(t)
138            | TypeDefKind::List(t)
139            | TypeDefKind::FixedSizeList(t, ..)
140            | TypeDefKind::Option(t)
141            | TypeDefKind::Future(Some(t))
142            | TypeDefKind::Stream(Some(t)) => self.visit_type(resolve, t),
143            TypeDefKind::Map(k, v) => {
144                self.visit_type(resolve, k);
145                self.visit_type(resolve, v);
146            }
147            TypeDefKind::Handle(handle) => match handle {
148                crate::Handle::Own(ty) => self.visit_type_id(resolve, *ty),
149                crate::Handle::Borrow(ty) => self.visit_type_id(resolve, *ty),
150            },
151            TypeDefKind::Resource => {}
152            TypeDefKind::Record(r) => {
153                for field in r.fields.iter() {
154                    self.visit_type(resolve, &field.ty);
155                }
156            }
157            TypeDefKind::Tuple(r) => {
158                for ty in r.types.iter() {
159                    self.visit_type(resolve, ty);
160                }
161            }
162            TypeDefKind::Variant(v) => {
163                for case in v.cases.iter() {
164                    if let Some(ty) = &case.ty {
165                        self.visit_type(resolve, ty);
166                    }
167                }
168            }
169            TypeDefKind::Result(r) => {
170                if let Some(ty) = &r.ok {
171                    self.visit_type(resolve, ty);
172                }
173                if let Some(ty) = &r.err {
174                    self.visit_type(resolve, ty);
175                }
176            }
177            TypeDefKind::Flags(_)
178            | TypeDefKind::Enum(_)
179            | TypeDefKind::Future(None)
180            | TypeDefKind::Stream(None) => {}
181            TypeDefKind::Unknown => unreachable!(),
182        }
183    }
184
185    fn visit_type(&mut self, resolve: &Resolve, ty: &Type) {
186        match ty {
187            Type::Id(id) => self.visit_type_id(resolve, *id),
188            _ => {}
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::{LiveTypes, Resolve};

    /// Parses `wit`, looks up the type named `ty` in the last interface of
    /// the resolve, and returns the names of all live types found from it, in
    /// the order they were recorded. Unnamed types are left out.
    fn live(wit: &str, ty: &str) -> Vec<String> {
        let mut resolve = Resolve::default();
        resolve.push_str("test.wit", wit).unwrap();

        let root = {
            let (_, interface) = resolve.interfaces.iter().next_back().unwrap();
            interface.types[ty]
        };

        let mut collected = LiveTypes::default();
        collected.add_type_id(&resolve, root);

        let mut names = Vec::new();
        for id in collected.iter() {
            // `extend` with an `Option` pushes the name only when present.
            names.extend(resolve.types[id].name.clone());
        }
        names
    }

    /// A type without dependencies is the only live type.
    #[test]
    fn no_deps() {
        let names = live(
            "
                package foo:bar;

                interface foo {
                    type t = u32;
                }
            ",
            "t",
        );
        assert_eq!(names, ["t"]);
    }

    /// An alias pulls in the type it refers to, dependency first.
    #[test]
    fn one_dep() {
        let names = live(
            "
                package foo:bar;

                interface foo {
                    type t = u32;
                    type u = t;
                }
            ",
            "u",
        );
        assert_eq!(names, ["t", "u"]);
    }

    /// Dependencies are followed transitively through tuple, variant, record
    /// and resource definitions, each recorded before its dependents.
    #[test]
    fn chain() {
        let names = live(
            "
                package foo:bar;

                interface foo {
                    resource t1;
                    record t2 {
                        x: t1,
                    }
                    variant t3 {
                        x(t2),
                    }
                    flags t4 { a }
                    enum t5 { a }
                    type t6 = tuple<t5, t4, t3>;
                }
            ",
            "t6",
        );
        assert_eq!(names, ["t5", "t4", "t1", "t2", "t3", "t6"]);
    }
}
264}