wgsl_parse/
syntax_impl.rs

1use itertools::Itertools;
2
3use crate::span::Spanned;
4
5use super::syntax::*;
6
7impl TranslationUnit {
8    /// New empty [`TranslationUnit`]
9    pub fn new() -> Self {
10        Self::default()
11    }
12
13    /// Remove all [`GlobalDeclaration::Void`] and [`Statement::Void`]
14    pub fn remove_voids(&mut self) {
15        self.global_declarations.retain_mut(|decl| match decl {
16            GlobalDeclaration::Void => false,
17            _ => {
18                decl.remove_voids();
19                true
20            }
21        })
22    }
23}
24
#[cfg(feature = "imports")]
impl ModulePath {
    /// Create a new module path from components.
    ///
    /// Precondition: the path components must be valid WGSL identifiers.
    pub fn new(origin: PathOrigin, components: Vec<String>) -> Self {
        Self { origin, components }
    }

    /// Create a new module path from a filesystem path.
    ///
    /// * Paths with a root (leading `/` on Unix) produce `package::` paths
    ///   ([`PathOrigin::Absolute`]).
    /// * Relative paths (starting with `.` or `..`) produce `self::` or `super::` paths
    ///   ([`PathOrigin::Relative`]).
    /// * Any other path produces a [`PathOrigin::Package`] path.
    /// * The file extension is ignored.
    /// * `.` and `..` components are resolved lexically, without touching the filesystem, so
    ///   symlinks are NOT followed. A `..` directly below the root is ignored, because the
    ///   parent of the root is the root itself.
    ///
    /// Precondition: the path components must be valid WGSL identifiers.
    pub fn from_path(path: impl AsRef<std::path::Path>) -> Self {
        use std::path::Component;
        let mut origin = PathOrigin::Package;
        let mut components = Vec::new();

        for comp in path.as_ref().with_extension("").components() {
            match comp {
                // Drive / UNC prefixes (Windows) carry no module information.
                Component::Prefix(_) => {}
                Component::RootDir => origin = PathOrigin::Absolute,
                Component::CurDir => {
                    // Only a leading `.` marks the path as relative; `a/./b` is just `a/b`.
                    if components.is_empty() && origin.is_package() {
                        origin = PathOrigin::Relative(0);
                    }
                }
                Component::ParentDir => {
                    // `..` first cancels a previously seen component; only when there is none
                    // left does it climb above the starting point.
                    if components.pop().is_none() {
                        origin = match origin {
                            PathOrigin::Relative(n) => PathOrigin::Relative(n + 1),
                            // The parent of the root is the root: stay absolute.
                            PathOrigin::Absolute => PathOrigin::Absolute,
                            PathOrigin::Package => PathOrigin::Relative(1),
                        };
                    }
                }
                Component::Normal(comp) => components.push(comp.to_string_lossy().into_owned()),
            }
        }

        Self { origin, components }
    }

    /// Append a component to the path.
    ///
    /// Precondition: the `item` must be a valid WGSL identifier.
    pub fn push(&mut self, item: &str) {
        self.components.push(item.to_string());
    }

    /// Get the first component of the module path.
    pub fn first(&self) -> Option<&str> {
        self.components.first().map(String::as_str)
    }

    /// Get the last component of the module path.
    pub fn last(&self) -> Option<&str> {
        self.components.last().map(String::as_str)
    }

    /// Append the components of `suffix` to the module path.
    pub fn join(mut self, suffix: impl IntoIterator<Item = String>) -> Self {
        self.components.extend(suffix);
        self
    }

    /// Resolve the relative module path `path` against `self`.
    ///
    /// Each `super` level of `path` (its [`PathOrigin::Relative`] count) drops one trailing
    /// component of `self`, then the components of `path` are appended. When `path` climbs
    /// above the first component of `self`, the excess levels are carried into the origin of
    /// the result as a relative origin.
    ///
    /// Returns `None` if `path` is not a relative path.
    pub fn join_path(&self, path: &Self) -> Option<Self> {
        match path.origin {
            PathOrigin::Relative(n) => {
                let len = self.components.len();
                // Leading components of `self` that survive `n` `super` steps.
                // (Must saturate: `n` may exceed `len`.)
                let to_keep = len.saturating_sub(n);
                // `super` steps that go above the first component of `self`.
                let overflow = n.saturating_sub(len);
                let components = self
                    .components
                    .iter()
                    .take(to_keep)
                    .chain(&path.components)
                    .cloned()
                    .collect_vec();
                let origin = if overflow == 0 {
                    self.origin
                } else {
                    match self.origin {
                        PathOrigin::Absolute | PathOrigin::Package => {
                            PathOrigin::Relative(overflow)
                        }
                        PathOrigin::Relative(m) => PathOrigin::Relative(m + overflow),
                    }
                };
                Some(Self { origin, components })
            }
            PathOrigin::Absolute | PathOrigin::Package => None,
        }
    }

    /// Whether `prefix` is a prefix of `self`: both paths have the same origin and the
    /// components of `prefix` are the leading components of `self`.
    pub fn starts_with(&self, prefix: &Self) -> bool {
        self.origin == prefix.origin && self.components.starts_with(&prefix.components)
    }

    /// Whether this is the empty path: a [`PathOrigin::Package`] path with no components.
    pub fn is_empty(&self) -> bool {
        self.origin.is_package() && self.components.is_empty()
    }
}
134
#[cfg(feature = "imports")]
impl Default for ModulePath {
    /// The empty path: a [`PathOrigin::Package`] path with no components.
    fn default() -> Self {
        let components = Vec::new();
        let origin = PathOrigin::Package;
        Self { origin, components }
    }
}
145
#[cfg(feature = "imports")]
impl<T: AsRef<std::path::Path>> From<T> for ModulePath {
    /// Convert a filesystem path into a module path using [`ModulePath::from_path`].
    fn from(path: T) -> Self {
        Self::from_path(path)
    }
}
152
153impl GlobalDeclaration {
154    /// Remove all [`Statement::Void`]
155    pub fn remove_voids(&mut self) {
156        if let GlobalDeclaration::Function(decl) = self {
157            decl.body.remove_voids();
158        }
159    }
160}
161
162impl TypeExpression {
163    /// New [`TypeExpression`] with no template.
164    pub fn new(ident: Ident) -> Self {
165        Self {
166            #[cfg(feature = "imports")]
167            path: None,
168            ident,
169            template_args: None,
170        }
171    }
172}
173
174impl CompoundStatement {
175    /// Remove all [`Statement::Void`]
176    pub fn remove_voids(&mut self) {
177        self.statements.retain_mut(|stmt| match stmt.node_mut() {
178            Statement::Void => false,
179            _ => {
180                stmt.remove_voids();
181                true
182            }
183        })
184    }
185}
186
187impl Statement {
188    /// Remove all [`Statement::Void`]
189    pub fn remove_voids(&mut self) {
190        match self {
191            Statement::Compound(stmt) => {
192                stmt.remove_voids();
193            }
194            Statement::If(stmt) => {
195                stmt.if_clause.body.remove_voids();
196                for clause in &mut stmt.else_if_clauses {
197                    clause.body.remove_voids();
198                }
199                if let Some(clause) = &mut stmt.else_clause {
200                    clause.body.remove_voids();
201                }
202            }
203            Statement::Switch(stmt) => stmt
204                .clauses
205                .iter_mut()
206                .for_each(|clause| clause.body.remove_voids()),
207            Statement::Loop(stmt) => stmt.body.remove_voids(),
208            Statement::For(stmt) => stmt.body.remove_voids(),
209            Statement::While(stmt) => stmt.body.remove_voids(),
210            _ => (),
211        }
212    }
213}
214
215impl AccessMode {
216    /// Is [`Self::Read`] or [`Self::ReadWrite`]
217    pub fn is_read(&self) -> bool {
218        matches!(self, Self::Read | Self::ReadWrite)
219    }
220    /// Is [`Self::Write`] or [`Self::ReadWrite`]
221    pub fn is_write(&self) -> bool {
222        matches!(self, Self::Write | Self::ReadWrite)
223    }
224}
225
226impl From<Ident> for TypeExpression {
227    fn from(name: Ident) -> Self {
228        Self {
229            #[cfg(feature = "imports")]
230            path: None,
231            ident: name,
232            template_args: None,
233        }
234    }
235}
236
237impl From<ExpressionNode> for ReturnStatement {
238    fn from(expression: ExpressionNode) -> Self {
239        Self {
240            #[cfg(feature = "attributes")]
241            attributes: Default::default(),
242            expression: Some(expression),
243        }
244    }
245}
246impl From<Expression> for ReturnStatement {
247    fn from(expression: Expression) -> Self {
248        Self::from(ExpressionNode::from(expression))
249    }
250}
251
252impl From<FunctionCall> for FunctionCallStatement {
253    fn from(call: FunctionCall) -> Self {
254        Self {
255            #[cfg(feature = "attributes")]
256            attributes: Default::default(),
257            call,
258        }
259    }
260}
261
262impl GlobalDeclaration {
263    /// Get the name of the declaration, if it has one.
264    pub fn ident(&self) -> Option<&Ident> {
265        match self {
266            GlobalDeclaration::Void => None,
267            GlobalDeclaration::Declaration(decl) => Some(&decl.ident),
268            GlobalDeclaration::TypeAlias(decl) => Some(&decl.ident),
269            GlobalDeclaration::Struct(decl) => Some(&decl.ident),
270            GlobalDeclaration::Function(decl) => Some(&decl.ident),
271            GlobalDeclaration::ConstAssert(_) => None,
272        }
273    }
274    /// Get the name of the declaration, if it has one.
275    pub fn ident_mut(&mut self) -> Option<&mut Ident> {
276        match self {
277            GlobalDeclaration::Void => None,
278            GlobalDeclaration::Declaration(decl) => Some(&mut decl.ident),
279            GlobalDeclaration::TypeAlias(decl) => Some(&mut decl.ident),
280            GlobalDeclaration::Struct(decl) => Some(&mut decl.ident),
281            GlobalDeclaration::Function(decl) => Some(&mut decl.ident),
282            GlobalDeclaration::ConstAssert(_) => None,
283        }
284    }
285}
286
287/// A trait implemented on all types that can be prefixed by attributes.
288pub trait Decorated {
289    /// List all attributes (`@name`) of a syntax node.
290    fn attributes(&self) -> &[Attribute];
291    /// List all attributes (`@name`) of a syntax node.
292    fn attributes_mut(&mut self) -> &mut [Attribute];
293    /// Remove attributes with predicate.
294    fn retain_attributes_mut<F>(&mut self, f: F)
295    where
296        F: FnMut(&mut Attribute) -> bool;
297}
298
299impl<T: Decorated> Decorated for Spanned<T> {
300    fn attributes(&self) -> &[Attribute] {
301        self.node().attributes()
302    }
303
304    fn attributes_mut(&mut self) -> &mut [Attribute] {
305        self.node_mut().attributes_mut()
306    }
307
308    fn retain_attributes_mut<F>(&mut self, f: F)
309    where
310        F: FnMut(&mut Attribute) -> bool,
311    {
312        self.node_mut().retain_attributes_mut(f)
313    }
314}
315
/// Implement [`Decorated`] for a type whose attributes are stored in a field named
/// `attributes` (the generated code reads, slices and `retain_mut`s that field).
macro_rules! impl_decorated_struct {
    ($target:ty) => {
        impl Decorated for $target {
            fn retain_attributes_mut<F>(&mut self, f: F)
            where
                F: FnMut(&mut Attribute) -> bool,
            {
                self.attributes.retain_mut(f);
            }

            fn attributes_mut(&mut self) -> &mut [Attribute] {
                &mut self.attributes
            }

            fn attributes(&self) -> &[Attribute] {
                &self.attributes
            }
        }
    };
}
334
// `ImportStatement` is only `Decorated` when both the `imports` and `attributes` features are
// enabled (presumably the type needs `imports` and its `attributes` field needs
// `attributes` — confirm in `syntax.rs`).
#[cfg(all(feature = "imports", feature = "attributes"))]
impl_decorated_struct!(ImportStatement);
337
#[cfg(feature = "attributes")]
impl Decorated for GlobalDirective {
    // Every variant wraps a directive node that implements `Decorated` itself (the directive
    // `impl_decorated_struct!` invocations share this feature gate), so each method simply
    // forwards to the wrapped directive.
    fn attributes(&self) -> &[Attribute] {
        match self {
            Self::Diagnostic(diagnostic) => diagnostic.attributes(),
            Self::Enable(enable) => enable.attributes(),
            Self::Requires(requires) => requires.attributes(),
        }
    }

    fn attributes_mut(&mut self) -> &mut [Attribute] {
        match self {
            Self::Diagnostic(diagnostic) => diagnostic.attributes_mut(),
            Self::Enable(enable) => enable.attributes_mut(),
            Self::Requires(requires) => requires.attributes_mut(),
        }
    }

    fn retain_attributes_mut<F>(&mut self, f: F)
    where
        F: FnMut(&mut Attribute) -> bool,
    {
        match self {
            Self::Diagnostic(diagnostic) => diagnostic.retain_attributes_mut(f),
            Self::Enable(enable) => enable.retain_attributes_mut(f),
            Self::Requires(requires) => requires.retain_attributes_mut(f),
        }
    }
}
367
// Directive nodes: their `Decorated` impls only exist with the `attributes` feature
// (presumably their `attributes` fields are gated the same way — confirm in `syntax.rs`).
#[cfg(feature = "attributes")]
impl_decorated_struct!(DiagnosticDirective);

#[cfg(feature = "attributes")]
impl_decorated_struct!(EnableDirective);

#[cfg(feature = "attributes")]
impl_decorated_struct!(RequiresDirective);
376
#[cfg(feature = "attributes")]
impl Decorated for GlobalDeclaration {
    // Every non-void variant wraps a declaration node with its own `Decorated` impl, so each
    // method forwards to it. `GlobalDeclaration::Void` has no attributes.
    fn attributes(&self) -> &[Attribute] {
        match self {
            Self::Void => &[],
            Self::Declaration(declaration) => declaration.attributes(),
            Self::TypeAlias(alias) => alias.attributes(),
            Self::Struct(structure) => structure.attributes(),
            Self::Function(function) => function.attributes(),
            Self::ConstAssert(assertion) => assertion.attributes(),
        }
    }

    fn attributes_mut(&mut self) -> &mut [Attribute] {
        match self {
            Self::Void => &mut [],
            Self::Declaration(declaration) => declaration.attributes_mut(),
            Self::TypeAlias(alias) => alias.attributes_mut(),
            Self::Struct(structure) => structure.attributes_mut(),
            Self::Function(function) => function.attributes_mut(),
            Self::ConstAssert(assertion) => assertion.attributes_mut(),
        }
    }

    fn retain_attributes_mut<F>(&mut self, f: F)
    where
        F: FnMut(&mut Attribute) -> bool,
    {
        match self {
            Self::Void => {}
            Self::Declaration(declaration) => declaration.retain_attributes_mut(f),
            Self::TypeAlias(alias) => alias.retain_attributes_mut(f),
            Self::Struct(structure) => structure.retain_attributes_mut(f),
            Self::Function(function) => function.retain_attributes_mut(f),
            Self::ConstAssert(assertion) => assertion.retain_attributes_mut(f),
        }
    }
}
415
// Declaration-level nodes. The invocations without `#[cfg]` are compiled in every build, and
// `impl_decorated_struct!` reads `self.attributes`, so those types must always have an
// `attributes` field (presumably because WGSL itself allows attributes on them). The gated
// ones are only `Decorated` with the `attributes` feature.
impl_decorated_struct!(Declaration);

#[cfg(feature = "attributes")]
impl_decorated_struct!(TypeAlias);

#[cfg(feature = "attributes")]
impl_decorated_struct!(Struct);

impl_decorated_struct!(StructMember);

impl_decorated_struct!(Function);

impl_decorated_struct!(FormalParameter);

#[cfg(feature = "attributes")]
impl_decorated_struct!(ConstAssert);
432
#[cfg(feature = "attributes")]
impl Decorated for Statement {
    // Every non-void variant wraps a statement node with its own `Decorated` impl, so each
    // method forwards to it. `Statement::Void` has no attributes.
    fn attributes(&self) -> &[Attribute] {
        match self {
            Self::Void => &[],
            Self::Compound(s) => s.attributes(),
            Self::Assignment(s) => s.attributes(),
            Self::Increment(s) => s.attributes(),
            Self::Decrement(s) => s.attributes(),
            Self::If(s) => s.attributes(),
            Self::Switch(s) => s.attributes(),
            Self::Loop(s) => s.attributes(),
            Self::For(s) => s.attributes(),
            Self::While(s) => s.attributes(),
            Self::Break(s) => s.attributes(),
            Self::Continue(s) => s.attributes(),
            Self::Return(s) => s.attributes(),
            Self::Discard(s) => s.attributes(),
            Self::FunctionCall(s) => s.attributes(),
            Self::ConstAssert(s) => s.attributes(),
            Self::Declaration(s) => s.attributes(),
        }
    }

    fn attributes_mut(&mut self) -> &mut [Attribute] {
        match self {
            Self::Void => &mut [],
            Self::Compound(s) => s.attributes_mut(),
            Self::Assignment(s) => s.attributes_mut(),
            Self::Increment(s) => s.attributes_mut(),
            Self::Decrement(s) => s.attributes_mut(),
            Self::If(s) => s.attributes_mut(),
            Self::Switch(s) => s.attributes_mut(),
            Self::Loop(s) => s.attributes_mut(),
            Self::For(s) => s.attributes_mut(),
            Self::While(s) => s.attributes_mut(),
            Self::Break(s) => s.attributes_mut(),
            Self::Continue(s) => s.attributes_mut(),
            Self::Return(s) => s.attributes_mut(),
            Self::Discard(s) => s.attributes_mut(),
            Self::FunctionCall(s) => s.attributes_mut(),
            Self::ConstAssert(s) => s.attributes_mut(),
            Self::Declaration(s) => s.attributes_mut(),
        }
    }

    fn retain_attributes_mut<F>(&mut self, f: F)
    where
        F: FnMut(&mut Attribute) -> bool,
    {
        match self {
            Self::Void => {}
            Self::Compound(s) => s.retain_attributes_mut(f),
            Self::Assignment(s) => s.retain_attributes_mut(f),
            Self::Increment(s) => s.retain_attributes_mut(f),
            Self::Decrement(s) => s.retain_attributes_mut(f),
            Self::If(s) => s.retain_attributes_mut(f),
            Self::Switch(s) => s.retain_attributes_mut(f),
            Self::Loop(s) => s.retain_attributes_mut(f),
            Self::For(s) => s.retain_attributes_mut(f),
            Self::While(s) => s.retain_attributes_mut(f),
            Self::Break(s) => s.retain_attributes_mut(f),
            Self::Continue(s) => s.retain_attributes_mut(f),
            Self::Return(s) => s.retain_attributes_mut(f),
            Self::Discard(s) => s.retain_attributes_mut(f),
            Self::FunctionCall(s) => s.retain_attributes_mut(f),
            Self::ConstAssert(s) => s.retain_attributes_mut(f),
            Self::Declaration(s) => s.retain_attributes_mut(f),
        }
    }
}
504
// Statement nodes. The invocations without `#[cfg]` are compiled in every build, and
// `impl_decorated_struct!` reads `self.attributes`, so those types must always have an
// `attributes` field (presumably because WGSL itself allows attributes on them). The gated
// ones are only `Decorated` with the `attributes` feature.
impl_decorated_struct!(CompoundStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(AssignmentStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(IncrementStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(DecrementStatement);

impl_decorated_struct!(IfStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(ElseIfClause);

#[cfg(feature = "attributes")]
impl_decorated_struct!(ElseClause);

impl_decorated_struct!(SwitchStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(SwitchClause);

impl_decorated_struct!(LoopStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(ContinuingStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(BreakIfStatement);

impl_decorated_struct!(ForStatement);

impl_decorated_struct!(WhileStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(BreakStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(ContinueStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(ReturnStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(DiscardStatement);

#[cfg(feature = "attributes")]
impl_decorated_struct!(FunctionCallStatement);