1use crate::{
2 Function, FunctionKind, InterfaceId, Resolve, Type, TypeDef, TypeDefKind, TypeId, WorldId,
3 WorldItem,
4};
5use indexmap::IndexSet;
6
7#[derive(Default)]
8pub struct LiveTypes {
9 set: IndexSet<TypeId>,
10}
11
12impl LiveTypes {
13 pub fn iter(&self) -> impl Iterator<Item = TypeId> + '_ {
14 self.set.iter().copied()
15 }
16
17 pub fn len(&self) -> usize {
18 self.set.len()
19 }
20
21 pub fn add_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
22 self.visit_interface(resolve, iface);
23 }
24
25 pub fn add_world(&mut self, resolve: &Resolve, world: WorldId) {
26 self.visit_world(resolve, world);
27 }
28
29 pub fn add_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
30 self.visit_world_item(resolve, item);
31 }
32
33 pub fn add_func(&mut self, resolve: &Resolve, func: &Function) {
34 self.visit_func(resolve, func);
35 }
36
37 pub fn add_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
38 self.visit_type_id(resolve, ty);
39 }
40
41 pub fn add_type(&mut self, resolve: &Resolve, ty: &Type) {
42 self.visit_type(resolve, ty);
43 }
44}
45
46impl TypeIdVisitor for LiveTypes {
47 fn before_visit_type_id(&mut self, id: TypeId) -> bool {
48 !self.set.contains(&id)
49 }
50
51 fn after_visit_type_id(&mut self, id: TypeId) {
52 assert!(self.set.insert(id));
53 }
54}
55
56/// Helper trait to walk the structure of a type and visit all `TypeId`s that
57/// it refers to, possibly transitively.
58pub trait TypeIdVisitor {
59 /// Callback invoked just before a type is visited.
60 ///
61 /// If this function returns `false` the type is not visited, otherwise it's
62 /// recursed into.
63 fn before_visit_type_id(&mut self, id: TypeId) -> bool {
64 let _ = id;
65 true
66 }
67
68 /// Callback invoked once a type is finished being visited.
69 fn after_visit_type_id(&mut self, id: TypeId) {
70 let _ = id;
71 }
72
73 fn visit_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
74 let iface = &resolve.interfaces[iface];
75 for (_, id) in iface.types.iter() {
76 self.visit_type_id(resolve, *id);
77 }
78 for (_, func) in iface.functions.iter() {
79 self.visit_func(resolve, func);
80 }
81 }
82
83 fn visit_world(&mut self, resolve: &Resolve, world: WorldId) {
84 let world = &resolve.worlds[world];
85 for (_, item) in world.imports.iter().chain(world.exports.iter()) {
86 self.visit_world_item(resolve, item);
87 }
88 }
89
90 fn visit_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
91 match item {
92 WorldItem::Interface { id, .. } => self.visit_interface(resolve, *id),
93 WorldItem::Function(f) => self.visit_func(resolve, f),
94 WorldItem::Type(t) => self.visit_type_id(resolve, *t),
95 }
96 }
97
98 fn visit_func(&mut self, resolve: &Resolve, func: &Function) {
99 match func.kind {
100 // This resource is live as it's attached to a static method but
101 // it's not guaranteed to be present in either params or results, so
102 // be sure to attach it here.
103 FunctionKind::Static(id) => self.visit_type_id(resolve, id),
104
105 // The resource these are attached to is in the params/results, so
106 // no need to re-add it here.
107 FunctionKind::Method(_) | FunctionKind::Constructor(_) => {}
108
109 FunctionKind::Freestanding => {}
110 }
111
112 for (_, ty) in func.params.iter() {
113 self.visit_type(resolve, ty);
114 }
115 for ty in func.results.iter_types() {
116 self.visit_type(resolve, ty);
117 }
118 }
119
120 fn visit_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
121 if self.before_visit_type_id(ty) {
122 self.visit_type_def(resolve, &resolve.types[ty]);
123 self.after_visit_type_id(ty);
124 }
125 }
126
127 fn visit_type_def(&mut self, resolve: &Resolve, ty: &TypeDef) {
128 match &ty.kind {
129 TypeDefKind::Type(t)
130 | TypeDefKind::List(t)
131 | TypeDefKind::Option(t)
132 | TypeDefKind::Future(Some(t))
133 | TypeDefKind::Stream(t) => self.visit_type(resolve, t),
134 TypeDefKind::Handle(handle) => match handle {
135 crate::Handle::Own(ty) => self.visit_type_id(resolve, *ty),
136 crate::Handle::Borrow(ty) => self.visit_type_id(resolve, *ty),
137 },
138 TypeDefKind::Resource => {}
139 TypeDefKind::Record(r) => {
140 for field in r.fields.iter() {
141 self.visit_type(resolve, &field.ty);
142 }
143 }
144 TypeDefKind::Tuple(r) => {
145 for ty in r.types.iter() {
146 self.visit_type(resolve, ty);
147 }
148 }
149 TypeDefKind::Variant(v) => {
150 for case in v.cases.iter() {
151 if let Some(ty) = &case.ty {
152 self.visit_type(resolve, ty);
153 }
154 }
155 }
156 TypeDefKind::Result(r) => {
157 if let Some(ty) = &r.ok {
158 self.visit_type(resolve, ty);
159 }
160 if let Some(ty) = &r.err {
161 self.visit_type(resolve, ty);
162 }
163 }
164 TypeDefKind::ErrorContext
165 | TypeDefKind::Flags(_)
166 | TypeDefKind::Enum(_)
167 | TypeDefKind::Future(None) => {}
168 TypeDefKind::Unknown => unreachable!(),
169 }
170 }
171
172 fn visit_type(&mut self, resolve: &Resolve, ty: &Type) {
173 match ty {
174 Type::Id(id) => self.visit_type_id(resolve, *id),
175 _ => {}
176 }
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::{LiveTypes, Resolve};

    /// Parses `wit` and finds the type named `ty` in the last interface of the
    /// resulting `Resolve`. Returns the names of all types reachable from it,
    /// in the order `LiveTypes` recorded them. Anonymous types are skipped.
    fn live(wit: &str, ty: &str) -> Vec<String> {
        let mut resolve = Resolve::default();
        resolve.push_str("test.wit", wit).unwrap();
        let (_, interface) = resolve.interfaces.iter().next_back().unwrap();
        let ty = interface.types[ty];
        let mut live = LiveTypes::default();
        live.add_type_id(&resolve, ty);

        live.iter()
            .filter_map(|ty| resolve.types[ty].name.clone())
            .collect()
    }

    #[test]
    fn no_deps() {
        let types = live(
            "
            package foo:bar;

            interface foo {
                type t = u32;
            }
            ",
            "t",
        );
        assert_eq!(types, ["t"]);
    }

    #[test]
    fn one_dep() {
        let types = live(
            "
            package foo:bar;

            interface foo {
                type t = u32;
                type u = t;
            }
            ",
            "u",
        );
        assert_eq!(types, ["t", "u"]);
    }

    #[test]
    fn chain() {
        let types = live(
            "
            package foo:bar;

            interface foo {
                resource t1;
                record t2 {
                    x: t1,
                }
                variant t3 {
                    x(t2),
                }
                flags t4 { a }
                enum t5 { a }
                type t6 = tuple<t5, t4, t3>;
            }
            ",
            "t6",
        );
        assert_eq!(types, ["t5", "t4", "t1", "t2", "t3", "t6"]);
    }

    /// A type shared by two dependents must be recorded exactly once, before
    /// both of them. This exercises the deduplication in
    /// `before_visit_type_id`.
    #[test]
    fn shared_dep() {
        let types = live(
            "
            package foo:bar;

            interface foo {
                type a = u32;
                record b {
                    x: a,
                }
                record c {
                    y: a,
                }
                type d = tuple<b, c>;
            }
            ",
            "d",
        );
        assert_eq!(types, ["a", "b", "c", "d"]);
    }
}
252