1use std::{
2 cell::RefCell,
3 collections::{HashMap, HashSet},
4 rc::Rc,
5};
6
7use itertools::Itertools;
8use wgsl_parse::{SyntaxNode, syntax::*};
9
10use crate::{Diagnostic, Error, Mangler, ResolveError, Resolver, SyntaxUtil, visit::Visit};
11
/// Imported items of one module, keyed by the name they are visible under in that module
/// (the `as` rename when present, otherwise the imported identifier itself).
type Imports = HashMap<Ident, ImportedItem>;
/// All loaded modules, keyed by their module path.
type Modules = HashMap<ModulePath, Rc<RefCell<Module>>>;
14
/// A single item brought into scope by an import statement.
#[derive(Clone, Debug)]
struct ImportedItem {
    /// Path of the module the item is imported from.
    path: ModulePath,
    /// Name of the item inside the module it is imported from (before any `as` rename).
    ident: Ident,
    /// Whether the import statement carries the `@publish` attribute. Only such imports may be
    /// followed when another module imports the item through this module.
    public: bool,
}
21
/// Errors raised while resolving imports between modules.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ImportError {
    /// The named symbol is declared more than once (not constructed in the code visible here).
    #[error("duplicate declaration of `{0}`")]
    DuplicateSymbol(String),
    /// Wraps a [`ResolveError`].
    #[error("{0}")]
    ResolveError(#[from] ResolveError),
    /// The module (first field) has no declaration named by the second field.
    #[error("module `{0}` has no declaration `{1}`")]
    MissingDecl(ModulePath, String),
    /// The item (first field) is imported by the module (second field) without `@publish`,
    /// yet another module tried to import it through that module.
    #[error(
        "import of `{0}` in module `{1}` is not `@publish`, but another module tried to import it"
    )]
    Private(String, ModulePath),
}

/// Short alias used when constructing import errors in this file.
type E = ImportError;
38
/// A loaded module: its parsed source plus the lookup tables used during import resolution.
#[derive(Debug)]
pub(crate) struct Module {
    pub(crate) source: TranslationUnit,
    pub(crate) path: ModulePath,
    /// Maps the identifier of each named global declaration to its index in
    /// `source.global_declarations`.
    idents: HashMap<Ident, usize>,
    /// Identifiers of declarations reached during resolution. Only these are retargeted, and
    /// when stripping, only these (plus const asserts) are kept by `assemble`.
    used_idents: RefCell<HashSet<Ident>>,
    /// The module's import statements, flattened into one map.
    imports: Imports,
}
47
48impl Module {
49 fn new(source: TranslationUnit, path: ModulePath) -> Self {
50 let idents = source
51 .global_declarations
52 .iter()
53 .enumerate()
54 .filter_map(|(i, decl)| decl.ident().map(|id| (id, i)))
55 .collect::<HashMap<_, _>>();
56
57 Self {
58 source,
59 path,
60 idents,
61 used_idents: Default::default(),
62 imports: Default::default(),
63 }
64 }
65
66 fn find_decl(&self, ident: &Ident) -> Option<(&Ident, &usize)> {
67 self.idents.get_key_value(ident).or_else(|| {
68 self.idents
69 .iter()
70 .find(|(id, _)| *id.name() == *ident.name())
71 })
72 }
73 fn find_import(&self, ident: &Ident) -> Option<(&Ident, &ImportedItem)> {
74 self.imports.get_key_value(ident).or_else(|| {
75 self.imports
76 .iter()
77 .find(|(id, _)| *id.name() == *ident.name())
78 })
79 }
80}
81
/// The modules loaded while resolving a root module, together with their load order.
#[derive(Debug)]
pub(crate) struct Resolutions {
    /// Loaded modules, keyed by path.
    modules: Modules,
    /// Module paths in the order they were pushed; the first entry is the root module.
    order: Vec<ModulePath>,
}
87
88impl Resolutions {
89 pub(crate) fn new(source: TranslationUnit, path: ModulePath) -> Self {
90 let mut resol = Self::new_uninit();
91 resol.push_module(Module::new(source, path));
92 resol
93 }
94 pub fn new_uninit() -> Self {
96 Resolutions {
97 modules: Default::default(),
98 order: Default::default(),
99 }
100 }
101 #[allow(unused)]
102 pub(crate) fn root_module(&self) -> Rc<RefCell<Module>> {
103 self.modules.get(self.root_path()).unwrap().clone() }
105 pub(crate) fn root_path(&self) -> &ModulePath {
106 self.order.first().unwrap() }
108 pub(crate) fn modules(&self) -> impl Iterator<Item = Rc<RefCell<Module>>> + '_ {
109 self.order.iter().map(|i| self.modules[i].clone())
110 }
111 pub(crate) fn push_module(&mut self, module: Module) -> Rc<RefCell<Module>> {
112 let path = module.path.clone();
113 let module = Rc::new(RefCell::new(module));
114 self.modules.insert(path.clone(), module.clone());
115 self.order.push(path);
116 module
117 }
118 pub(crate) fn into_module_order(self) -> Vec<ModulePath> {
119 self.order
120 }
121}
122
123fn err_with_module(e: Error, module: &Module, resolver: &impl Resolver) -> Error {
124 Error::from(
125 Diagnostic::from(e)
126 .with_module_path(module.path.clone(), resolver.display_name(&module.path)),
127 )
128}
129
130fn load_module<R: Resolver>(
132 path: &ModulePath,
133 resolutions: &mut Resolutions,
134 resolver: &R,
135 onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
136) -> Result<Rc<RefCell<Module>>, Error> {
137 if let Some(module) = resolutions.modules.get(path) {
138 return Ok(module.clone());
139 }
140
141 let source = resolver.resolve_module(path)?;
142 load_module_with_source(source, path, resolutions, resolver, onload)
143}
144
145fn load_module_with_source<R: Resolver>(
146 source: TranslationUnit,
147 path: &ModulePath,
148 resolutions: &mut Resolutions,
149 resolver: &R,
150 onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
151) -> Result<Rc<RefCell<Module>>, Error> {
152 let module = Module::new(source, path.clone());
153 let module = resolutions.push_module(module);
154
155 let imports = flatten_imports(&module.borrow().source.imports, path);
156 {
157 let mut module = module.borrow_mut();
158 module.imports = imports;
159 module.source.retarget_idents();
160 }
161
162 {
163 let module = module.borrow();
164 onload(&module, resolutions, resolver)
165 .map_err(|e| err_with_module(e, &module, resolver))?;
166 }
167
168 Ok(module)
169}
170
171fn resolve_decl<R: Resolver>(
174 module: &Module,
175 ident: &Ident,
176 resolutions: &mut Resolutions,
177 resolver: &R,
178 onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
179) -> Result<(), Error> {
180 if let Some((_, n)) = module.find_decl(ident) {
181 let decl = module.source.global_declarations.get(*n).unwrap().node();
182 if let Some(ident) = decl.ident() {
183 if !module.used_idents.borrow_mut().insert(ident) {
184 return Ok(());
185 }
186 }
187
188 for ty in Visit::<TypeExpression>::visit(decl) {
189 resolve_ty(module, ty, resolutions, resolver, onload)?;
190 }
191 Ok(())
192 } else if let Some((_, item)) = module.find_import(ident) {
193 if item.public {
195 let ext_mod = load_module(&item.path, resolutions, resolver, onload)?;
197 let ext_mod = ext_mod.borrow();
198 resolve_decl(&ext_mod, &item.ident, resolutions, resolver, onload)
199 .map_err(|e| err_with_module(e, &ext_mod, resolver))
200 } else {
201 Err(E::Private(ident.to_string(), module.path.clone()).into())
202 }
203 } else {
204 Err(E::MissingDecl(module.path.clone(), ident.to_string()).into())
205 }
206}
207
208fn resolve_ty<R: Resolver>(
210 module: &Module,
211 ty: &TypeExpression,
212 resolutions: &mut Resolutions,
213 resolver: &R,
214 onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
215) -> Result<(), Error> {
216 for ty in Visit::<TypeExpression>::visit(ty) {
218 resolve_ty(module, ty, resolutions, resolver, onload)?;
219 }
220
221 let (ext_path, ext_id) = if let Some(path) = &ty.path {
223 let path = resolve_inline_path(path, &module.path, &module.imports);
224 (path, &ty.ident)
225 } else if let Some(item) = module.imports.get(&ty.ident) {
226 (item.path.clone(), &item.ident)
227 } else {
228 if module.idents.contains_key(&ty.ident) {
230 resolve_decl(module, &ty.ident, resolutions, resolver, onload)?;
231 }
232 return Ok(());
233 };
234
235 if ext_path == module.path {
238 if module.idents.contains_key(&ty.ident) {
239 return Ok(());
240 } else {
241 return Err(E::MissingDecl(ext_path, ty.ident.to_string()).into());
242 }
243 }
244
245 let ext_mod = load_module(&ext_path, resolutions, resolver, &onload)?;
247 let ext_mod = ext_mod.borrow();
248
249 resolve_decl(&ext_mod, ext_id, resolutions, resolver, onload)
251 .map_err(|e| err_with_module(e, &ext_mod, resolver))
252}
253
254pub fn resolve_lazy<'a>(
272 keep: impl IntoIterator<Item = &'a Ident>,
273 source: TranslationUnit,
274 path: &ModulePath,
275 resolver: &impl Resolver,
276) -> Result<Resolutions, Error> {
277 fn resolve_module(
278 module: &Module,
279 resolutions: &mut Resolutions,
280 resolver: &impl Resolver,
281 ) -> Result<(), Error> {
282 let const_asserts = module
285 .source
286 .global_declarations
287 .iter()
288 .filter(|decl| decl.is_const_assert());
289
290 for decl in const_asserts {
291 for ty in Visit::<TypeExpression>::visit(decl.node()) {
292 resolve_ty(module, ty, resolutions, resolver, &resolve_module)?;
293 }
294 }
295
296 Ok(())
297 }
298
299 let mut resolutions = Resolutions::new_uninit();
300 let module =
301 load_module_with_source(source, path, &mut resolutions, resolver, &resolve_module)?;
302
303 {
304 let module = module.borrow();
305 for id in keep {
306 resolve_decl(&module, id, &mut resolutions, resolver, &resolve_module)
307 .map_err(|e| err_with_module(e, &module, resolver))?;
308 }
309 }
310
311 resolutions.retarget()?;
312 Ok(resolutions)
313}
314
315pub fn resolve_eager(
317 source: TranslationUnit,
318 path: &ModulePath,
319 resolver: &impl Resolver,
320) -> Result<Resolutions, Error> {
321 fn resolve_module(
322 module: &Module,
323 resolutions: &mut Resolutions,
324 resolver: &impl Resolver,
325 ) -> Result<(), Error> {
326 for item in module.imports.values() {
328 load_module(&item.path, resolutions, resolver, &resolve_module)?;
329 }
330
331 for decl in &module.source.global_declarations {
332 if let Some(ident) = decl.ident() {
333 resolve_decl(module, &ident, resolutions, resolver, &resolve_module)?;
334 } else {
335 for ty in Visit::<TypeExpression>::visit(decl.node()) {
336 resolve_ty(module, ty, resolutions, resolver, &resolve_module)?;
337 }
338 }
339 }
340
341 Ok(())
342 }
343
344 let mut resolutions = Resolutions::new_uninit();
345 load_module_with_source(source, path, &mut resolutions, resolver, &resolve_module)?;
346
347 resolutions.retarget()?;
348 Ok(resolutions)
349}
350
351fn flatten_imports(imports: &[ImportStatement], path: &ModulePath) -> Imports {
353 fn rec(content: &ImportContent, path: ModulePath, public: bool, res: &mut Imports) {
354 match content {
355 ImportContent::Item(item) => {
356 let ident = item.rename.as_ref().unwrap_or(&item.ident).clone();
357 res.insert(
358 ident,
359 ImportedItem {
360 path,
361 ident: item.ident.clone(),
362 public,
363 },
364 );
365 }
366 ImportContent::Collection(coll) => {
367 for import in coll {
368 let path = path.clone().join(import.path.iter().cloned());
369 rec(&import.content, path, public, res);
370 }
371 }
372 }
373 }
374
375 let mut res = Imports::default();
376
377 for import in imports {
378 let public = import.attributes.iter().any(|attr| attr.is_publish());
379 match &import.path {
380 Some(import_path) => {
381 let path = path.join_path(import_path);
382 rec(&import.content, path, public, &mut res);
383 }
384 None => {
385 match &import.content {
388 ImportContent::Item(_) => {
389 }
392 ImportContent::Collection(coll) => {
393 for import in coll {
394 let mut components = import.path.iter().cloned();
395 if let Some(pkg_name) = components.next() {
396 let path = ModulePath::new(
398 PathOrigin::Package(pkg_name),
399 components.collect_vec(),
400 );
401 rec(&import.content, path, public, &mut res);
402 }
403 }
404 }
405 }
406 }
407 }
408 }
409
410 res
411}
412
413fn resolve_inline_path(
418 path: &ModulePath,
419 parent_path: &ModulePath,
420 imports: &Imports,
421) -> ModulePath {
422 match &path.origin {
423 PathOrigin::Package(pkg_name) => {
424 let imported_item = imports.iter().find(|(ident, _)| *ident.name() == *pkg_name);
426
427 if let Some((_, ext_item)) = imported_item {
428 let mut res = ext_item.path.clone(); res.push(&ext_item.ident.name()); res.join(path.components.iter().cloned())
433 } else {
434 parent_path.join_path(path)
435 }
436 }
437 _ => parent_path.join_path(path),
438 }
439}
440
441pub(crate) fn mangle_decls<'a>(
442 wgsl: &'a mut TranslationUnit,
443 path: &'a ModulePath,
444 mangler: &impl Mangler,
445) {
446 wgsl.global_declarations
447 .iter_mut()
448 .filter_map(|decl| decl.ident())
449 .for_each(|mut ident| {
450 let new_name = mangler.mangle(path, &ident.name());
451 ident.rename(new_name.clone());
452 })
453}
454
455impl Resolutions {
456 fn retarget(&self) -> Result<(), Error> {
465 fn find_ext_ident(
466 modules: &Modules,
467 src_path: &ModulePath,
468 src_id: &Ident,
469 ) -> Option<Ident> {
470 let module = modules.get(src_path)?;
472 let module = module.borrow();
475
476 module
477 .find_decl(src_id)
478 .map(|(id, _)| id.clone())
479 .or_else(|| {
480 module
482 .find_import(src_id)
483 .and_then(|(_, item)| find_ext_ident(modules, &item.path, &item.ident))
484 })
485 }
486
487 fn retarget_ty(
488 modules: &Modules,
489 module_path: &ModulePath,
490 module_imports: &Imports,
491 module_idents: &HashMap<Ident, usize>,
492 ty: &mut TypeExpression,
493 ) -> Result<(), Error> {
494 for ty in Visit::<TypeExpression>::visit_mut(ty) {
496 retarget_ty(modules, module_path, module_imports, module_idents, ty)?;
497 }
498
499 let (ext_path, ext_id) = if let Some(path) = &ty.path {
500 let res = resolve_inline_path(path, module_path, module_imports);
501 (res, &ty.ident)
502 } else if let Some(item) = module_imports.get(&ty.ident) {
503 (item.path.clone(), &item.ident)
504 } else {
505 return Ok(());
507 };
508
509 if ext_path == *module_path {
512 let local_id = module_idents
513 .iter()
514 .find(|(id, _)| *id.name() == *ext_id.name())
515 .map(|(id, _)| id.clone())
516 .ok_or_else(|| E::MissingDecl(ext_path, ext_id.to_string()))?;
517 ty.path = None;
518 ty.ident = local_id;
519 }
520 else if let Some(ext_id) = find_ext_ident(modules, &ext_path, ext_id) {
522 ty.path = None;
523 ty.ident = ext_id;
524 }
525 else {
528 return Err(E::MissingDecl(ext_path, ext_id.to_string()).into());
529 }
530
531 Ok(())
532 }
533
534 for module in self.modules.values() {
535 let mut module = module.borrow_mut();
536 let module = &mut *module;
537
538 for decl in &mut module.source.global_declarations {
539 if let Some(id) = decl.ident() {
542 if !module.used_idents.borrow().contains(&id) {
543 continue;
544 }
545 }
546
547 for ty in Visit::<TypeExpression>::visit_mut(decl.node_mut()) {
548 retarget_ty(
549 &self.modules,
550 &module.path,
551 &module.imports,
552 &module.idents,
553 ty,
554 )?;
555 }
556 }
557 }
558
559 Ok(())
560 }
561
562 pub(crate) fn mangle(&mut self, mangler: &impl Mangler, mangle_root: bool) {
566 let root_path = self.root_path().clone();
567 for (path, module) in self.modules.iter_mut() {
568 if mangle_root || path != &root_path {
569 let mut module = module.borrow_mut();
570 mangle_decls(&mut module.source, path, mangler);
571 }
572 }
573 }
574
575 pub(crate) fn assemble(&self, strip: bool) -> TranslationUnit {
578 let mut wesl = TranslationUnit::default();
579 for module in self.modules() {
580 let module = module.borrow();
581 if strip {
582 wesl.global_declarations.extend(
583 module
584 .source
585 .global_declarations
586 .iter()
587 .filter(|decl| {
588 decl.is_const_assert()
589 || decl
590 .ident()
591 .is_some_and(|id| module.used_idents.borrow().contains(&id))
592 })
593 .cloned(),
594 );
595 } else {
596 wesl.global_declarations
597 .extend(module.source.global_declarations.clone());
598 }
599 wesl.global_directives
600 .extend(module.source.global_directives.clone());
601 }
602 wesl.global_directives.dedup();
607 wesl
608 }
609}