wesl/
import.rs

1use std::{
2    cell::RefCell,
3    collections::{HashMap, HashSet},
4    rc::Rc,
5};
6
7use itertools::Itertools;
8use wgsl_parse::{SyntaxNode, syntax::*};
9
10use crate::{Diagnostic, Error, Mangler, ResolveError, Resolver, SyntaxUtil, visit::Visit};
11
/// Imported items of a module, keyed by the name they are bound to locally
/// (the `as` alias when present, otherwise the imported name).
type Imports = HashMap<Ident, ImportedItem>;
/// Loaded modules, keyed by module path. Each module is shared behind `Rc<RefCell<_>>` so it
/// can be borrowed while the owning `Resolutions` is mutated (e.g. when loading more modules).
type Modules = HashMap<ModulePath, Rc<RefCell<Module>>>;

/// A single item brought into scope by an import statement.
#[derive(Clone, Debug)]
struct ImportedItem {
    /// Path of the module the item is imported from.
    path: ModulePath,
    ident: Ident, // this is the ident's original name before `as` renaming.
    /// Whether the import statement carries the `@publish` attribute, i.e. whether other
    /// modules may import this item through this module (re-export).
    public: bool,
}
21
/// Error produced during import resolution.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ImportError {
    /// A name is declared more than once (not raised in this part of the file).
    #[error("duplicate declaration of `{0}`")]
    DuplicateSymbol(String),
    /// Wraps a [`ResolveError`].
    #[error("{0}")]
    ResolveError(#[from] ResolveError),
    /// The module at the given path has no declaration with the given name
    /// (fields: module path, declaration name).
    #[error("module `{0}` has no declaration `{1}`")]
    MissingDecl(ModulePath, String),
    /// An item was reached through a module that imports it without `@publish`
    /// (fields: item name, module path).
    #[error(
        "import of `{0}` in module `{1}` is not `@publish`, but another module tried to import it"
    )]
    Private(String, ModulePath),
}

/// Short alias for [`ImportError`], used when constructing errors in this module.
type E = ImportError;
38
/// A module loaded during import resolution, with the lookup tables used to resolve it.
#[derive(Debug)]
pub(crate) struct Module {
    /// The module's parsed source.
    pub(crate) source: TranslationUnit,
    /// The path of this module.
    pub(crate) path: ModulePath,
    idents: HashMap<Ident, usize>,        // lookup (ident, decl_index)
    // Declarations already marked as used (their dependencies have been resolved). Kept in a
    // `RefCell` because resolution records them through a shared `&Module`.
    used_idents: RefCell<HashSet<Ident>>, // used idents that have already been usage-analyzed
    // Flattened import statements, keyed by local name (see `flatten_imports`).
    imports: Imports,
}
47
48impl Module {
49    fn new(source: TranslationUnit, path: ModulePath) -> Self {
50        let idents = source
51            .global_declarations
52            .iter()
53            .enumerate()
54            .filter_map(|(i, decl)| decl.ident().map(|id| (id, i)))
55            .collect::<HashMap<_, _>>();
56
57        Self {
58            source,
59            path,
60            idents,
61            used_idents: Default::default(),
62            imports: Default::default(),
63        }
64    }
65
66    fn find_decl(&self, ident: &Ident) -> Option<(&Ident, &usize)> {
67        self.idents.get_key_value(ident).or_else(|| {
68            self.idents
69                .iter()
70                .find(|(id, _)| *id.name() == *ident.name())
71        })
72    }
73    fn find_import(&self, ident: &Ident) -> Option<(&Ident, &ImportedItem)> {
74        self.imports.get_key_value(ident).or_else(|| {
75            self.imports
76                .iter()
77                .find(|(id, _)| *id.name() == *ident.name())
78        })
79    }
80}
81
/// The modules loaded during import resolution.
#[derive(Debug)]
pub(crate) struct Resolutions {
    /// Loaded modules, by path.
    modules: Modules,
    /// Module paths in the order they were pushed; the first one is the root module.
    order: Vec<ModulePath>,
}
87
88impl Resolutions {
89    pub(crate) fn new(source: TranslationUnit, path: ModulePath) -> Self {
90        let mut resol = Self::new_uninit();
91        resol.push_module(Module::new(source, path));
92        resol
93    }
94    /// Warning: you *must* call `push_module` right after this.
95    pub fn new_uninit() -> Self {
96        Resolutions {
97            modules: Default::default(),
98            order: Default::default(),
99        }
100    }
101    #[allow(unused)]
102    pub(crate) fn root_module(&self) -> Rc<RefCell<Module>> {
103        self.modules.get(self.root_path()).unwrap().clone() // safety: new() requires push_module
104    }
105    pub(crate) fn root_path(&self) -> &ModulePath {
106        self.order.first().unwrap() // safety: new() requires push_module
107    }
108    pub(crate) fn modules(&self) -> impl Iterator<Item = Rc<RefCell<Module>>> + '_ {
109        self.order.iter().map(|i| self.modules[i].clone())
110    }
111    pub(crate) fn push_module(&mut self, module: Module) -> Rc<RefCell<Module>> {
112        let path = module.path.clone();
113        let module = Rc::new(RefCell::new(module));
114        self.modules.insert(path.clone(), module.clone());
115        self.order.push(path);
116        module
117    }
118    pub(crate) fn into_module_order(self) -> Vec<ModulePath> {
119        self.order
120    }
121}
122
123fn err_with_module(e: Error, module: &Module, resolver: &impl Resolver) -> Error {
124    Error::from(
125        Diagnostic::from(e)
126            .with_module_path(module.path.clone(), resolver.display_name(&module.path)),
127    )
128}
129
130/// get or load a module with the resolver.
131fn load_module<R: Resolver>(
132    path: &ModulePath,
133    resolutions: &mut Resolutions,
134    resolver: &R,
135    onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
136) -> Result<Rc<RefCell<Module>>, Error> {
137    if let Some(module) = resolutions.modules.get(path) {
138        return Ok(module.clone());
139    }
140
141    let source = resolver.resolve_module(path)?;
142    load_module_with_source(source, path, resolutions, resolver, onload)
143}
144
145fn load_module_with_source<R: Resolver>(
146    source: TranslationUnit,
147    path: &ModulePath,
148    resolutions: &mut Resolutions,
149    resolver: &R,
150    onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
151) -> Result<Rc<RefCell<Module>>, Error> {
152    let module = Module::new(source, path.clone());
153    let module = resolutions.push_module(module);
154
155    let imports = flatten_imports(&module.borrow().source.imports, path);
156    {
157        let mut module = module.borrow_mut();
158        module.imports = imports;
159        module.source.retarget_idents();
160    }
161
162    {
163        let module = module.borrow();
164        onload(&module, resolutions, resolver)
165            .map_err(|e| err_with_module(e, &module, resolver))?;
166    }
167
168    Ok(module)
169}
170
171/// load the modules that a declaration (named by its identifier) refers to, recursively.
172/// the identifier must not be a builtin.
173fn resolve_decl<R: Resolver>(
174    module: &Module,
175    ident: &Ident,
176    resolutions: &mut Resolutions,
177    resolver: &R,
178    onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
179) -> Result<(), Error> {
180    if let Some((_, n)) = module.find_decl(ident) {
181        let decl = module.source.global_declarations.get(*n).unwrap().node();
182        if let Some(ident) = decl.ident() {
183            if !module.used_idents.borrow_mut().insert(ident) {
184                return Ok(());
185            }
186        }
187
188        for ty in Visit::<TypeExpression>::visit(decl) {
189            resolve_ty(module, ty, resolutions, resolver, onload)?;
190        }
191        Ok(())
192    } else if let Some((_, item)) = module.find_import(ident) {
193        // the declaration can be a re-export (`@publish import`)
194        if item.public {
195            // load the external module for this imported item
196            let ext_mod = load_module(&item.path, resolutions, resolver, onload)?;
197            let ext_mod = ext_mod.borrow();
198            resolve_decl(&ext_mod, &item.ident, resolutions, resolver, onload)
199                .map_err(|e| err_with_module(e, &ext_mod, resolver))
200        } else {
201            Err(E::Private(ident.to_string(), module.path.clone()).into())
202        }
203    } else {
204        Err(E::MissingDecl(module.path.clone(), ident.to_string()).into())
205    }
206}
207
208/// load the modules that a TypeExpression refers to, recursively.
209fn resolve_ty<R: Resolver>(
210    module: &Module,
211    ty: &TypeExpression,
212    resolutions: &mut Resolutions,
213    resolver: &R,
214    onload: &impl Fn(&Module, &mut Resolutions, &R) -> Result<(), Error>,
215) -> Result<(), Error> {
216    // first, the recursive call
217    for ty in Visit::<TypeExpression>::visit(ty) {
218        resolve_ty(module, ty, resolutions, resolver, onload)?;
219    }
220
221    // get the path and identifier referred to by the TypeExpression, if it is imported
222    let (ext_path, ext_id) = if let Some(path) = &ty.path {
223        let path = resolve_inline_path(path, &module.path, &module.imports);
224        (path, &ty.ident)
225    } else if let Some(item) = module.imports.get(&ty.ident) {
226        (item.path.clone(), &item.ident)
227    } else {
228        // This is a local declaration or a builtin, we mark the ident as used.
229        if module.idents.contains_key(&ty.ident) {
230            resolve_decl(module, &ty.ident, resolutions, resolver, onload)?;
231        }
232        return Ok(());
233    };
234
235    // if the import path points to a local declaration, we just check that it exists
236    // and we're done.
237    if ext_path == module.path {
238        if module.idents.contains_key(&ty.ident) {
239            return Ok(());
240        } else {
241            return Err(E::MissingDecl(ext_path, ty.ident.to_string()).into());
242        }
243    }
244
245    // load the external module for this imported item
246    let ext_mod = load_module(&ext_path, resolutions, resolver, &onload)?;
247    let ext_mod = ext_mod.borrow();
248
249    // and ensure the declaration's dependencies are resolved too
250    resolve_decl(&ext_mod, ext_id, resolutions, resolver, onload)
251        .map_err(|e| err_with_module(e, &ext_mod, resolver))
252}
253
254// XXX: it's quite messy.
255/// Load all modules "used" transitively by the root module. Make external idents point at
256/// the right declaration in the external module.
257///
258/// It is "lazy" because external modules are loaded only if used by the `keep` declarations
259/// or module-scope `const_assert`s.
260///
261/// This approach is only valid when stripping is enabled. Otherwise, unused declarations
262/// may refer to declarations in unused modules, and mangling will panic.
263///
264/// "used": used declarations in the root module are the `keep` parameter. Used declarations
265/// in other modules are those reached by `keep` declarations, recursively.
266/// Module-scope `const_assert`s are always included.
267///
268/// Returns a list of [`Module`]s with the list of their "used" idents.
269///
270/// See also: [`resolve_eager`]
271pub fn resolve_lazy<'a>(
272    keep: impl IntoIterator<Item = &'a Ident>,
273    source: TranslationUnit,
274    path: &ModulePath,
275    resolver: &impl Resolver,
276) -> Result<Resolutions, Error> {
277    fn resolve_module(
278        module: &Module,
279        resolutions: &mut Resolutions,
280        resolver: &impl Resolver,
281    ) -> Result<(), Error> {
282        // const_asserts of used modules must be included.
283        // https://github.com/wgsl-tooling-wg/wesl-spec/issues/66
284        let const_asserts = module
285            .source
286            .global_declarations
287            .iter()
288            .filter(|decl| decl.is_const_assert());
289
290        for decl in const_asserts {
291            for ty in Visit::<TypeExpression>::visit(decl.node()) {
292                resolve_ty(module, ty, resolutions, resolver, &resolve_module)?;
293            }
294        }
295
296        Ok(())
297    }
298
299    let mut resolutions = Resolutions::new_uninit();
300    let module =
301        load_module_with_source(source, path, &mut resolutions, resolver, &resolve_module)?;
302
303    {
304        let module = module.borrow();
305        for id in keep {
306            resolve_decl(&module, id, &mut resolutions, resolver, &resolve_module)
307                .map_err(|e| err_with_module(e, &module, resolver))?;
308        }
309    }
310
311    resolutions.retarget()?;
312    Ok(resolutions)
313}
314
315/// Load all [`Module`]s referenced by the root module.
316pub fn resolve_eager(
317    source: TranslationUnit,
318    path: &ModulePath,
319    resolver: &impl Resolver,
320) -> Result<Resolutions, Error> {
321    fn resolve_module(
322        module: &Module,
323        resolutions: &mut Resolutions,
324        resolver: &impl Resolver,
325    ) -> Result<(), Error> {
326        // resolve all module imports
327        for item in module.imports.values() {
328            load_module(&item.path, resolutions, resolver, &resolve_module)?;
329        }
330
331        for decl in &module.source.global_declarations {
332            if let Some(ident) = decl.ident() {
333                resolve_decl(module, &ident, resolutions, resolver, &resolve_module)?;
334            } else {
335                for ty in Visit::<TypeExpression>::visit(decl.node()) {
336                    resolve_ty(module, ty, resolutions, resolver, &resolve_module)?;
337                }
338            }
339        }
340
341        Ok(())
342    }
343
344    let mut resolutions = Resolutions::new_uninit();
345    load_module_with_source(source, path, &mut resolutions, resolver, &resolve_module)?;
346
347    resolutions.retarget()?;
348    Ok(resolutions)
349}
350
351/// Flatten imports to a list.
352fn flatten_imports(imports: &[ImportStatement], path: &ModulePath) -> Imports {
353    fn rec(content: &ImportContent, path: ModulePath, public: bool, res: &mut Imports) {
354        match content {
355            ImportContent::Item(item) => {
356                let ident = item.rename.as_ref().unwrap_or(&item.ident).clone();
357                res.insert(
358                    ident,
359                    ImportedItem {
360                        path,
361                        ident: item.ident.clone(),
362                        public,
363                    },
364                );
365            }
366            ImportContent::Collection(coll) => {
367                for import in coll {
368                    let path = path.clone().join(import.path.iter().cloned());
369                    rec(&import.content, path, public, res);
370                }
371            }
372        }
373    }
374
375    let mut res = Imports::default();
376
377    for import in imports {
378        let public = import.attributes.iter().any(|attr| attr.is_publish());
379        match &import.path {
380            Some(import_path) => {
381                let path = path.join_path(import_path);
382                rec(&import.content, path, public, &mut res);
383            }
384            None => {
385                // this covers two cases: `import foo;` and `import {foo, ..};`.
386                // COMBAK: these edge-cases smell
387                match &import.content {
388                    ImportContent::Item(_) => {
389                        // `import foo`, this import statement does nothing currently.
390                        // In the future, it may become a visibility/re-export mechanism.
391                    }
392                    ImportContent::Collection(coll) => {
393                        for import in coll {
394                            let mut components = import.path.iter().cloned();
395                            if let Some(pkg_name) = components.next() {
396                                // `import {foo::bar}`, foo becomes the package name.
397                                let path = ModulePath::new(
398                                    PathOrigin::Package(pkg_name),
399                                    components.collect_vec(),
400                                );
401                                rec(&import.content, path, public, &mut res);
402                            }
403                        }
404                    }
405                }
406            }
407        }
408    }
409
410    res
411}
412
413/// Finds the normalized module path for an inline import.
414///
415/// Inline imports differ from import statements only in case of package imports:
416/// the package component may refer to a local import shadowing the package name.
417fn resolve_inline_path(
418    path: &ModulePath,
419    parent_path: &ModulePath,
420    imports: &Imports,
421) -> ModulePath {
422    match &path.origin {
423        PathOrigin::Package(pkg_name) => {
424            // the path could be either a package, of referencing an imported module alias.
425            let imported_item = imports.iter().find(|(ident, _)| *ident.name() == *pkg_name);
426
427            if let Some((_, ext_item)) = imported_item {
428                // this inline path references an imported item. Example:
429                // import a::b::c as foo; foo::bar::baz() => a::b::c::bar::baz()
430                let mut res = ext_item.path.clone(); // a::b
431                res.push(&ext_item.ident.name()); // c
432                res.join(path.components.iter().cloned())
433            } else {
434                parent_path.join_path(path)
435            }
436        }
437        _ => parent_path.join_path(path),
438    }
439}
440
441pub(crate) fn mangle_decls<'a>(
442    wgsl: &'a mut TranslationUnit,
443    path: &'a ModulePath,
444    mangler: &impl Mangler,
445) {
446    wgsl.global_declarations
447        .iter_mut()
448        .filter_map(|decl| decl.ident())
449        .for_each(|mut ident| {
450            let new_name = mangler.mangle(path, &ident.name());
451            ident.rename(new_name.clone());
452        })
453}
454
455impl Resolutions {
456    /// Retarget used identifiers to point at the corresponding declaration.
457    ///
458    /// We call this after resolve, because it is mutating the modules, and we want to keep
459    /// mutations and lookups separate if possible, to avoid multiple mut borrows.
460    ///
461    /// Panics
462    /// * if an identifier has no corresponding declaration.
463    /// * if a module is already borrowed.
464    fn retarget(&self) -> Result<(), Error> {
465        fn find_ext_ident(
466            modules: &Modules,
467            src_path: &ModulePath,
468            src_id: &Ident,
469        ) -> Option<Ident> {
470            // load the external module for this external ident
471            let module = modules.get(src_path)?;
472            // SAFETY: since this is an external ident, it cannot be in the currently
473            // borrowed module.
474            let module = module.borrow();
475
476            module
477                .find_decl(src_id)
478                .map(|(id, _)| id.clone())
479                .or_else(|| {
480                    // or it could be a re-exported import with `@publish`
481                    module
482                        .find_import(src_id)
483                        .and_then(|(_, item)| find_ext_ident(modules, &item.path, &item.ident))
484                })
485        }
486
487        fn retarget_ty(
488            modules: &Modules,
489            module_path: &ModulePath,
490            module_imports: &Imports,
491            module_idents: &HashMap<Ident, usize>,
492            ty: &mut TypeExpression,
493        ) -> Result<(), Error> {
494            // first the recursive call
495            for ty in Visit::<TypeExpression>::visit_mut(ty) {
496                retarget_ty(modules, module_path, module_imports, module_idents, ty)?;
497            }
498
499            let (ext_path, ext_id) = if let Some(path) = &ty.path {
500                let res = resolve_inline_path(path, module_path, module_imports);
501                (res, &ty.ident)
502            } else if let Some(item) = module_imports.get(&ty.ident) {
503                (item.path.clone(), &item.ident)
504            } else {
505                // points to a local decl, we stop here.
506                return Ok(());
507            };
508
509            // if the import path points to a local decl.
510            // this must be a special case to avoid 2 mut borrows of the current module.
511            if ext_path == *module_path {
512                let local_id = module_idents
513                    .iter()
514                    .find(|(id, _)| *id.name() == *ext_id.name())
515                    .map(|(id, _)| id.clone())
516                    .ok_or_else(|| E::MissingDecl(ext_path, ext_id.to_string()))?;
517                ty.path = None;
518                ty.ident = local_id;
519            }
520            // get the ident of the external declaration pointed to by the type
521            else if let Some(ext_id) = find_ext_ident(modules, &ext_path, ext_id) {
522                ty.path = None;
523                ty.ident = ext_id;
524            }
525            // the imported ident is used, but has no declaration!
526            // this code path should not be reached, as this is already checked in resolve().
527            else {
528                return Err(E::MissingDecl(ext_path, ext_id.to_string()).into());
529            }
530
531            Ok(())
532        }
533
534        for module in self.modules.values() {
535            let mut module = module.borrow_mut();
536            let module = &mut *module;
537
538            for decl in &mut module.source.global_declarations {
539                // we only retarged used declarations. Other declarations are not checked.
540                // unused declarations can even contain invalid code.
541                if let Some(id) = decl.ident() {
542                    if !module.used_idents.borrow().contains(&id) {
543                        continue;
544                    }
545                }
546
547                for ty in Visit::<TypeExpression>::visit_mut(decl.node_mut()) {
548                    retarget_ty(
549                        &self.modules,
550                        &module.path,
551                        &module.imports,
552                        &module.idents,
553                        ty,
554                    )?;
555                }
556            }
557        }
558
559        Ok(())
560    }
561
562    /// Mangle all declarations in all modules. Should be called after [`Self::retarget`].
563    ///
564    /// Panics if a module is already borrowed.
565    pub(crate) fn mangle(&mut self, mangler: &impl Mangler, mangle_root: bool) {
566        let root_path = self.root_path().clone();
567        for (path, module) in self.modules.iter_mut() {
568            if mangle_root || path != &root_path {
569                let mut module = module.borrow_mut();
570                mangle_decls(&mut module.source, path, mangler);
571            }
572        }
573    }
574
575    /// Merge all declarations into a single module. If the `strip` flag is set, it will
576    /// copy over only used declarations.
577    pub(crate) fn assemble(&self, strip: bool) -> TranslationUnit {
578        let mut wesl = TranslationUnit::default();
579        for module in self.modules() {
580            let module = module.borrow();
581            if strip {
582                wesl.global_declarations.extend(
583                    module
584                        .source
585                        .global_declarations
586                        .iter()
587                        .filter(|decl| {
588                            decl.is_const_assert()
589                                || decl
590                                    .ident()
591                                    .is_some_and(|id| module.used_idents.borrow().contains(&id))
592                        })
593                        .cloned(),
594                );
595            } else {
596                wesl.global_declarations
597                    .extend(module.source.global_declarations.clone());
598            }
599            wesl.global_directives
600                .extend(module.source.global_directives.clone());
601        }
602        // TODO: <https://github.com/wgsl-tooling-wg/wesl-spec/issues/71>
603        // currently the behavior is:
604        // * include all directives used (if strip)
605        // * include all directives (if not strip)
606        wesl.global_directives.dedup();
607        wesl
608    }
609}