wgsl_parse/
syntax_impl.rs

1use itertools::Itertools;
2
3use crate::span::Spanned;
4
5use super::syntax::*;
6
7impl TranslationUnit {
8    /// New empty [`TranslationUnit`]
9    pub fn new() -> Self {
10        Self::default()
11    }
12
13    /// Remove all [`GlobalDeclaration::Void`] and [`Statement::Void`]
14    pub fn remove_voids(&mut self) {
15        self.global_declarations.retain_mut(|decl| match decl {
16            GlobalDeclaration::Void => false,
17            _ => {
18                decl.remove_voids();
19                true
20            }
21        })
22    }
23}
24
25#[cfg(feature = "imports")]
26impl ModulePath {
27    /// Create a new module path from components.
28    ///
29    /// Precondition: the path components must be valid WGSL identifiers.
30    pub fn new(origin: PathOrigin, components: Vec<String>) -> Self {
31        Self { origin, components }
32    }
33    /// Create a new module path from a filesystem path.
34    ///
35    /// * Paths with a root (leading `/` on Unix) produce `package::` paths.
36    /// * Relative paths (starting with `.` or `..`) produce `self::` or `super::` paths.
37    /// * The file extension is ignored.
38    /// * The path is canonicalized and to do so it does NOT follow symlinks.
39    ///
40    /// Precondition: the path components must be valid WGSL identifiers.
41    pub fn from_path(path: impl AsRef<std::path::Path>) -> Self {
42        use std::path::Component;
43        let mut origin = PathOrigin::Package;
44        let mut components = Vec::new();
45
46        for comp in path.as_ref().with_extension("").components() {
47            match comp {
48                Component::Prefix(_) => {}
49                Component::RootDir => origin = PathOrigin::Absolute,
50                Component::CurDir => {
51                    if components.is_empty() && origin.is_package() {
52                        origin = PathOrigin::Relative(0);
53                    }
54                }
55                Component::ParentDir => {
56                    if components.is_empty() {
57                        if let PathOrigin::Relative(n) = &mut origin {
58                            *n += 1;
59                        } else {
60                            origin = PathOrigin::Relative(1)
61                        }
62                    } else {
63                        components.pop();
64                    }
65                }
66                Component::Normal(comp) => components.push(comp.to_string_lossy().to_string()),
67            }
68        }
69
70        Self { origin, components }
71    }
72    /// Create a `PathBuf` from a `ModulePath`.
73    ///
74    /// * `package::` paths are rooted (start with `/`).
75    /// * self::` or `super::` are relative (starting with `.` or `..`)`.
76    /// * There is no file extension.
77    pub fn to_path_buf(&self) -> std::path::PathBuf {
78        use std::path::PathBuf;
79        let mut fs_path = match self.origin {
80            PathOrigin::Absolute => PathBuf::from("/"),
81            PathOrigin::Relative(0) => PathBuf::from("."),
82            PathOrigin::Relative(n) => PathBuf::from_iter((0..n).map(|_| "..")),
83            PathOrigin::Package => PathBuf::new(),
84        };
85        fs_path.extend(&self.components);
86        fs_path
87    }
88    /// Append a component to the path.
89    ///
90    /// Precondition: the `item` must be a valid WGSL identifier.
91    pub fn push(&mut self, item: &str) {
92        self.components.push(item.to_string());
93    }
94    /// Get the first component of the module path.
95    pub fn first(&self) -> Option<&str> {
96        self.components.first().map(String::as_str)
97    }
98    /// Get the last component of the module path.
99    pub fn last(&self) -> Option<&str> {
100        self.components.last().map(String::as_str)
101    }
102    /// Append `suffix` to the module path.
103    pub fn join(mut self, suffix: impl IntoIterator<Item = String>) -> Self {
104        self.components.extend(suffix);
105        self
106    }
107    /// Append `suffix` to the module path.
108    /// the suffix must be a relative module path.
109    pub fn join_path(&self, path: &Self) -> Option<Self> {
110        match path.origin {
111            PathOrigin::Relative(n) => {
112                let to_keep = self.components.len().max(n) - n;
113                let components = self
114                    .components
115                    .iter()
116                    .take(to_keep)
117                    .chain(&path.components)
118                    .cloned()
119                    .collect_vec();
120                let origin = match self.origin {
121                    PathOrigin::Absolute | PathOrigin::Package => {
122                        if n > self.components.len() {
123                            PathOrigin::Relative(n - self.components.len())
124                        } else {
125                            self.origin
126                        }
127                    }
128                    PathOrigin::Relative(m) => {
129                        if n > self.components.len() {
130                            PathOrigin::Relative(m + n - self.components.len())
131                        } else {
132                            self.origin
133                        }
134                    }
135                };
136                Some(Self { origin, components })
137            }
138            _ => None,
139        }
140    }
141    pub fn starts_with(&self, prefix: &Self) -> bool {
142        self.origin == prefix.origin
143            && prefix.components.len() >= self.components.len()
144            && prefix
145                .components
146                .iter()
147                .zip(&self.components)
148                .all(|(a, b)| a == b)
149    }
150    pub fn is_empty(&self) -> bool {
151        self.origin.is_package() && self.components.is_empty()
152    }
153}
154
155#[cfg(feature = "imports")]
156impl Default for ModulePath {
157    /// The path that is represented as ``, i.e. a package import with no components.
158    fn default() -> Self {
159        Self {
160            origin: PathOrigin::Package,
161            components: Vec::new(),
162        }
163    }
164}
165
166#[cfg(feature = "imports")]
167impl<T: AsRef<std::path::Path>> From<T> for ModulePath {
168    fn from(value: T) -> Self {
169        ModulePath::from_path(value.as_ref())
170    }
171}
172
173impl GlobalDeclaration {
174    /// Remove all [`Statement::Void`]
175    pub fn remove_voids(&mut self) {
176        if let GlobalDeclaration::Function(decl) = self {
177            decl.body.remove_voids();
178        }
179    }
180}
181
182impl TypeExpression {
183    /// New [`TypeExpression`] with no template.
184    pub fn new(ident: Ident) -> Self {
185        Self {
186            #[cfg(feature = "imports")]
187            path: None,
188            ident,
189            template_args: None,
190        }
191    }
192}
193
194impl CompoundStatement {
195    /// Remove all [`Statement::Void`]
196    pub fn remove_voids(&mut self) {
197        self.statements.retain_mut(|stmt| match stmt.node_mut() {
198            Statement::Void => false,
199            _ => {
200                stmt.remove_voids();
201                true
202            }
203        })
204    }
205}
206
207impl Statement {
208    /// Remove all [`Statement::Void`]
209    pub fn remove_voids(&mut self) {
210        match self {
211            Statement::Compound(stmt) => {
212                stmt.remove_voids();
213            }
214            Statement::If(stmt) => {
215                stmt.if_clause.body.remove_voids();
216                for clause in &mut stmt.else_if_clauses {
217                    clause.body.remove_voids();
218                }
219                if let Some(clause) = &mut stmt.else_clause {
220                    clause.body.remove_voids();
221                }
222            }
223            Statement::Switch(stmt) => stmt
224                .clauses
225                .iter_mut()
226                .for_each(|clause| clause.body.remove_voids()),
227            Statement::Loop(stmt) => stmt.body.remove_voids(),
228            Statement::For(stmt) => stmt.body.remove_voids(),
229            Statement::While(stmt) => stmt.body.remove_voids(),
230            _ => (),
231        }
232    }
233}
234
235impl AccessMode {
236    /// Is [`Self::Read`] or [`Self::ReadWrite`]
237    pub fn is_read(&self) -> bool {
238        matches!(self, Self::Read | Self::ReadWrite)
239    }
240    /// Is [`Self::Write`] or [`Self::ReadWrite`]
241    pub fn is_write(&self) -> bool {
242        matches!(self, Self::Write | Self::ReadWrite)
243    }
244}
245
246impl From<Ident> for TypeExpression {
247    fn from(name: Ident) -> Self {
248        Self {
249            #[cfg(feature = "imports")]
250            path: None,
251            ident: name,
252            template_args: None,
253        }
254    }
255}
256
257impl From<ExpressionNode> for ReturnStatement {
258    fn from(expression: ExpressionNode) -> Self {
259        Self {
260            #[cfg(feature = "attributes")]
261            attributes: Default::default(),
262            expression: Some(expression),
263        }
264    }
265}
266impl From<Expression> for ReturnStatement {
267    fn from(expression: Expression) -> Self {
268        Self::from(ExpressionNode::from(expression))
269    }
270}
271
272impl From<FunctionCall> for FunctionCallStatement {
273    fn from(call: FunctionCall) -> Self {
274        Self {
275            #[cfg(feature = "attributes")]
276            attributes: Default::default(),
277            call,
278        }
279    }
280}
281
282impl GlobalDeclaration {
283    /// Get the name of the declaration, if it has one.
284    pub fn ident(&self) -> Option<&Ident> {
285        match self {
286            GlobalDeclaration::Void => None,
287            GlobalDeclaration::Declaration(decl) => Some(&decl.ident),
288            GlobalDeclaration::TypeAlias(decl) => Some(&decl.ident),
289            GlobalDeclaration::Struct(decl) => Some(&decl.ident),
290            GlobalDeclaration::Function(decl) => Some(&decl.ident),
291            GlobalDeclaration::ConstAssert(_) => None,
292        }
293    }
294    /// Get the name of the declaration, if it has one.
295    pub fn ident_mut(&mut self) -> Option<&mut Ident> {
296        match self {
297            GlobalDeclaration::Void => None,
298            GlobalDeclaration::Declaration(decl) => Some(&mut decl.ident),
299            GlobalDeclaration::TypeAlias(decl) => Some(&mut decl.ident),
300            GlobalDeclaration::Struct(decl) => Some(&mut decl.ident),
301            GlobalDeclaration::Function(decl) => Some(&mut decl.ident),
302            GlobalDeclaration::ConstAssert(_) => None,
303        }
304    }
305}
306
307/// A trait implemented on all types that can be prefixed by attributes.
308pub trait Decorated {
309    /// List all attributes (`@name`) of a syntax node.
310    fn attributes(&self) -> &[Attribute];
311    /// List all attributes (`@name`) of a syntax node.
312    fn attributes_mut(&mut self) -> &mut [Attribute];
313    /// Remove attributes with predicate.
314    fn retain_attributes_mut<F>(&mut self, f: F)
315    where
316        F: FnMut(&mut Attribute) -> bool;
317}
318
319impl<T: Decorated> Decorated for Spanned<T> {
320    fn attributes(&self) -> &[Attribute] {
321        self.node().attributes()
322    }
323
324    fn attributes_mut(&mut self) -> &mut [Attribute] {
325        self.node_mut().attributes_mut()
326    }
327
328    fn retain_attributes_mut<F>(&mut self, f: F)
329    where
330        F: FnMut(&mut Attribute) -> bool,
331    {
332        self.node_mut().retain_attributes_mut(f)
333    }
334}
335
// Implement `Decorated` for a struct that stores its attributes in a field named
// `attributes` (a `Vec`-like list of `Attribute`).
macro_rules! impl_decorated_struct {
    ($ty:ty) => {
        impl Decorated for $ty {
            fn attributes(&self) -> &[Attribute] {
                &self.attributes[..]
            }

            fn attributes_mut(&mut self) -> &mut [Attribute] {
                &mut self.attributes[..]
            }

            fn retain_attributes_mut<F>(&mut self, f: F)
            where
                F: FnMut(&mut Attribute) -> bool,
            {
                self.attributes.retain_mut(f);
            }
        }
    };
}
354
// Import statements only carry attributes when both features are enabled.
#[cfg(all(feature = "imports", feature = "attributes"))]
impl_decorated_struct! { ImportStatement }
357
#[cfg(feature = "attributes")]
impl Decorated for GlobalDirective {
    // Each directive kind is itself `Decorated` (see the `impl_decorated_struct!` invocations
    // for the directive types), so every method forwards to the directive in the variant.

    fn attributes(&self) -> &[Attribute] {
        match self {
            GlobalDirective::Diagnostic(diagnostic) => diagnostic.attributes(),
            GlobalDirective::Enable(enable) => enable.attributes(),
            GlobalDirective::Requires(requires) => requires.attributes(),
        }
    }

    fn attributes_mut(&mut self) -> &mut [Attribute] {
        match self {
            GlobalDirective::Diagnostic(diagnostic) => diagnostic.attributes_mut(),
            GlobalDirective::Enable(enable) => enable.attributes_mut(),
            GlobalDirective::Requires(requires) => requires.attributes_mut(),
        }
    }

    fn retain_attributes_mut<F>(&mut self, f: F)
    where
        F: FnMut(&mut Attribute) -> bool,
    {
        match self {
            GlobalDirective::Diagnostic(diagnostic) => diagnostic.retain_attributes_mut(f),
            GlobalDirective::Enable(enable) => enable.retain_attributes_mut(f),
            GlobalDirective::Requires(requires) => requires.retain_attributes_mut(f),
        }
    }
}
387
// Global directives only carry attributes with the `attributes` feature.
#[cfg(feature = "attributes")]
impl_decorated_struct! { DiagnosticDirective }

#[cfg(feature = "attributes")]
impl_decorated_struct! { EnableDirective }

#[cfg(feature = "attributes")]
impl_decorated_struct! { RequiresDirective }
396
#[cfg(feature = "attributes")]
impl Decorated for GlobalDeclaration {
    // Every non-void declaration kind is itself `Decorated`, so each method forwards to the
    // declaration held by the variant; `Void` has no attributes.

    fn attributes(&self) -> &[Attribute] {
        match self {
            GlobalDeclaration::Void => &[],
            GlobalDeclaration::Declaration(decl) => decl.attributes(),
            GlobalDeclaration::TypeAlias(alias) => alias.attributes(),
            GlobalDeclaration::Struct(strukt) => strukt.attributes(),
            GlobalDeclaration::Function(func) => func.attributes(),
            GlobalDeclaration::ConstAssert(assert) => assert.attributes(),
        }
    }

    fn attributes_mut(&mut self) -> &mut [Attribute] {
        match self {
            GlobalDeclaration::Void => &mut [],
            GlobalDeclaration::Declaration(decl) => decl.attributes_mut(),
            GlobalDeclaration::TypeAlias(alias) => alias.attributes_mut(),
            GlobalDeclaration::Struct(strukt) => strukt.attributes_mut(),
            GlobalDeclaration::Function(func) => func.attributes_mut(),
            GlobalDeclaration::ConstAssert(assert) => assert.attributes_mut(),
        }
    }

    fn retain_attributes_mut<F>(&mut self, f: F)
    where
        F: FnMut(&mut Attribute) -> bool,
    {
        match self {
            GlobalDeclaration::Void => {}
            GlobalDeclaration::Declaration(decl) => decl.retain_attributes_mut(f),
            GlobalDeclaration::TypeAlias(alias) => alias.retain_attributes_mut(f),
            GlobalDeclaration::Struct(strukt) => strukt.retain_attributes_mut(f),
            GlobalDeclaration::Function(func) => func.retain_attributes_mut(f),
            GlobalDeclaration::ConstAssert(assert) => assert.retain_attributes_mut(f),
        }
    }
}
435
436impl_decorated_struct!(Declaration);
437
438#[cfg(feature = "attributes")]
439impl_decorated_struct!(TypeAlias);
440
441#[cfg(feature = "attributes")]
442impl_decorated_struct!(Struct);
443
444impl_decorated_struct!(StructMember);
445
446impl_decorated_struct!(Function);
447
448impl_decorated_struct!(FormalParameter);
449
450#[cfg(feature = "attributes")]
451impl_decorated_struct!(ConstAssert);
452
#[cfg(feature = "attributes")]
impl Decorated for Statement {
    // Every non-void statement kind is itself `Decorated`, so each method forwards to the
    // statement held by the variant; `Void` has no attributes.

    fn attributes(&self) -> &[Attribute] {
        match self {
            Statement::Void => &[],
            Statement::Compound(s) => s.attributes(),
            Statement::Assignment(s) => s.attributes(),
            Statement::Increment(s) => s.attributes(),
            Statement::Decrement(s) => s.attributes(),
            Statement::If(s) => s.attributes(),
            Statement::Switch(s) => s.attributes(),
            Statement::Loop(s) => s.attributes(),
            Statement::For(s) => s.attributes(),
            Statement::While(s) => s.attributes(),
            Statement::Break(s) => s.attributes(),
            Statement::Continue(s) => s.attributes(),
            Statement::Return(s) => s.attributes(),
            Statement::Discard(s) => s.attributes(),
            Statement::FunctionCall(s) => s.attributes(),
            Statement::ConstAssert(s) => s.attributes(),
            Statement::Declaration(s) => s.attributes(),
        }
    }

    fn attributes_mut(&mut self) -> &mut [Attribute] {
        match self {
            Statement::Void => &mut [],
            Statement::Compound(s) => s.attributes_mut(),
            Statement::Assignment(s) => s.attributes_mut(),
            Statement::Increment(s) => s.attributes_mut(),
            Statement::Decrement(s) => s.attributes_mut(),
            Statement::If(s) => s.attributes_mut(),
            Statement::Switch(s) => s.attributes_mut(),
            Statement::Loop(s) => s.attributes_mut(),
            Statement::For(s) => s.attributes_mut(),
            Statement::While(s) => s.attributes_mut(),
            Statement::Break(s) => s.attributes_mut(),
            Statement::Continue(s) => s.attributes_mut(),
            Statement::Return(s) => s.attributes_mut(),
            Statement::Discard(s) => s.attributes_mut(),
            Statement::FunctionCall(s) => s.attributes_mut(),
            Statement::ConstAssert(s) => s.attributes_mut(),
            Statement::Declaration(s) => s.attributes_mut(),
        }
    }

    fn retain_attributes_mut<F>(&mut self, f: F)
    where
        F: FnMut(&mut Attribute) -> bool,
    {
        match self {
            Statement::Void => {}
            Statement::Compound(s) => s.retain_attributes_mut(f),
            Statement::Assignment(s) => s.retain_attributes_mut(f),
            Statement::Increment(s) => s.retain_attributes_mut(f),
            Statement::Decrement(s) => s.retain_attributes_mut(f),
            Statement::If(s) => s.retain_attributes_mut(f),
            Statement::Switch(s) => s.retain_attributes_mut(f),
            Statement::Loop(s) => s.retain_attributes_mut(f),
            Statement::For(s) => s.retain_attributes_mut(f),
            Statement::While(s) => s.retain_attributes_mut(f),
            Statement::Break(s) => s.retain_attributes_mut(f),
            Statement::Continue(s) => s.retain_attributes_mut(f),
            Statement::Return(s) => s.retain_attributes_mut(f),
            Statement::Discard(s) => s.retain_attributes_mut(f),
            Statement::FunctionCall(s) => s.retain_attributes_mut(f),
            Statement::ConstAssert(s) => s.retain_attributes_mut(f),
            Statement::Declaration(s) => s.retain_attributes_mut(f),
        }
    }
}
524
525impl_decorated_struct!(CompoundStatement);
526
527#[cfg(feature = "attributes")]
528impl_decorated_struct!(AssignmentStatement);
529
530#[cfg(feature = "attributes")]
531impl_decorated_struct!(IncrementStatement);
532
533#[cfg(feature = "attributes")]
534impl_decorated_struct!(DecrementStatement);
535
536impl_decorated_struct!(IfStatement);
537
538#[cfg(feature = "attributes")]
539impl_decorated_struct!(ElseIfClause);
540
541#[cfg(feature = "attributes")]
542impl_decorated_struct!(ElseClause);
543
544impl_decorated_struct!(SwitchStatement);
545
546#[cfg(feature = "attributes")]
547impl_decorated_struct!(SwitchClause);
548
549impl_decorated_struct!(LoopStatement);
550
551#[cfg(feature = "attributes")]
552impl_decorated_struct!(ContinuingStatement);
553
554#[cfg(feature = "attributes")]
555impl_decorated_struct!(BreakIfStatement);
556
557impl_decorated_struct!(ForStatement);
558
559impl_decorated_struct!(WhileStatement);
560
561#[cfg(feature = "attributes")]
562impl_decorated_struct!(BreakStatement);
563
564#[cfg(feature = "attributes")]
565impl_decorated_struct!(ContinueStatement);
566
567#[cfg(feature = "attributes")]
568impl_decorated_struct!(ReturnStatement);
569
570#[cfg(feature = "attributes")]
571impl_decorated_struct!(DiscardStatement);
572
573#[cfg(feature = "attributes")]
574impl_decorated_struct!(FunctionCallStatement);