wgsl_parse/
parser_support.rs

//! Support functions called from the actions of the LALRPOP-generated parser.
2
3use std::str::FromStr;
4
5use itertools::Itertools;
6
7use crate::{
8    error::CustomLalrError,
9    span::{Span, Spanned},
10    syntax::*,
11};
12
/// Shorthand for the error type returned by the parser support functions in this module.
type E = CustomLalrError;
14
/// One postfix component consumed by `apply_components`.
pub(crate) enum Component {
    /// Named component access; becomes `Expression::NamedComponent`.
    Named(Ident),
    /// Index access; the node is the index expression and becomes `Expression::Indexing`.
    Index(ExpressionNode),
}
19
20pub(crate) fn apply_components(
21    expr: Expression,
22    span: Span,
23    components: Vec<Spanned<Component>>,
24) -> Expression {
25    components.into_iter().fold(expr, |base, comp| {
26        let span = span.extend(comp.span());
27        let base = Spanned::new(base, span);
28        match comp.into_inner() {
29            Component::Named(component) => {
30                Expression::NamedComponent(NamedComponentExpression { base, component })
31            }
32            Component::Index(index) => Expression::Indexing(IndexingExpression { base, index }),
33        }
34    })
35}
36
37impl FromStr for DeclarationKind {
38    type Err = ();
39
40    fn from_str(s: &str) -> Result<Self, Self::Err> {
41        match s {
42            "const" => Ok(Self::Const),
43            "override" => Ok(Self::Override),
44            "let" => Ok(Self::Let),
45            "var" => Ok(Self::Var(None)),
46            _ => Err(()),
47        }
48    }
49}
50
51impl FromStr for AddressSpace {
52    type Err = ();
53
54    fn from_str(s: &str) -> Result<Self, Self::Err> {
55        match s {
56            "function" => Ok(Self::Function),
57            "private" => Ok(Self::Private),
58            "workgroup" => Ok(Self::Workgroup),
59            "uniform" => Ok(Self::Uniform),
60            "storage" => Ok(Self::Storage(None)),
61            // "WGSL predeclares an enumerant for each address space, except for the handle address space."
62            // "handle" => Ok(Self::Handle),
63            _ => Err(()),
64        }
65    }
66}
67
68impl FromStr for AccessMode {
69    type Err = ();
70
71    fn from_str(s: &str) -> Result<Self, Self::Err> {
72        match s {
73            "read" => Ok(Self::Read),
74            "write" => Ok(Self::Write),
75            "read_write" => Ok(Self::ReadWrite),
76            _ => Err(()),
77        }
78    }
79}
80
81impl FromStr for DiagnosticSeverity {
82    type Err = ();
83
84    fn from_str(s: &str) -> Result<Self, Self::Err> {
85        match s {
86            "error" => Ok(Self::Error),
87            "warning" => Ok(Self::Warning),
88            "info" => Ok(Self::Info),
89            "off" => Ok(Self::Off),
90            _ => Err(()),
91        }
92    }
93}
94
95impl FromStr for BuiltinValue {
96    type Err = ();
97
98    fn from_str(s: &str) -> Result<Self, Self::Err> {
99        match s {
100            "vertex_index" => Ok(Self::VertexIndex),
101            "instance_index" => Ok(Self::InstanceIndex),
102            "position" => Ok(Self::Position),
103            "front_facing" => Ok(Self::FrontFacing),
104            "frag_depth" => Ok(Self::FragDepth),
105            "sample_index" => Ok(Self::SampleIndex),
106            "sample_mask" => Ok(Self::SampleMask),
107            "local_invocation_id" => Ok(Self::LocalInvocationId),
108            "local_invocation_index" => Ok(Self::LocalInvocationIndex),
109            "global_invocation_id" => Ok(Self::GlobalInvocationId),
110            "workgroup_id" => Ok(Self::WorkgroupId),
111            "num_workgroups" => Ok(Self::NumWorkgroups),
112            _ => Err(()),
113        }
114    }
115}
116
117impl FromStr for InterpolationType {
118    type Err = ();
119
120    fn from_str(s: &str) -> Result<Self, Self::Err> {
121        match s {
122            "perspective" => Ok(Self::Perspective),
123            "linear" => Ok(Self::Linear),
124            "flat" => Ok(Self::Flat),
125            _ => Err(()),
126        }
127    }
128}
129
130impl FromStr for InterpolationSampling {
131    type Err = ();
132
133    fn from_str(s: &str) -> Result<Self, Self::Err> {
134        match s {
135            "center" => Ok(Self::Center),
136            "centroid" => Ok(Self::Centroid),
137            "sample" => Ok(Self::Sample),
138            "first" => Ok(Self::First),
139            "either" => Ok(Self::Either),
140            _ => Err(()),
141        }
142    }
143}
144
145fn one_arg(arguments: Option<Vec<ExpressionNode>>) -> Option<ExpressionNode> {
146    match arguments {
147        Some(mut args) => (args.len() == 1).then(|| args.pop().unwrap()),
148        None => None,
149    }
150}
151fn two_args(arguments: Option<Vec<ExpressionNode>>) -> Option<(ExpressionNode, ExpressionNode)> {
152    match arguments {
153        Some(args) => (args.len() == 2).then(|| args.into_iter().collect_tuple().unwrap()),
154        None => None,
155    }
156}
157fn zero_args(arguments: Option<Vec<ExpressionNode>>) -> bool {
158    arguments.is_none()
159}
160fn ident(expr: ExpressionNode) -> Option<Ident> {
161    match expr.into_inner() {
162        Expression::TypeOrIdentifier(TypeExpression {
163            #[cfg(feature = "imports")]
164                path: _,
165            ident,
166            template_args: None,
167        }) => Some(ident),
168        _ => None,
169    }
170}
171
172pub(crate) fn parse_attribute(
173    name: String,
174    args: Option<Vec<ExpressionNode>>,
175) -> Result<Attribute, E> {
176    match name.as_str() {
177        "align" => match one_arg(args) {
178            Some(expr) => Ok(Attribute::Align(expr)),
179            _ => Err(E::Attribute("align", "expected 1 argument")),
180        },
181        "binding" => match one_arg(args) {
182            Some(expr) => Ok(Attribute::Binding(expr)),
183            _ => Err(E::Attribute("binding", "expected 1 argument")),
184        },
185        "blend_src" => match one_arg(args) {
186            Some(expr) => Ok(Attribute::BlendSrc(expr)),
187            _ => Err(E::Attribute("blend_src", "expected 1 argument")),
188        },
189        "builtin" => match one_arg(args) {
190            Some(expr) => match ident(expr).and_then(|id| id.name().parse().ok()) {
191                Some(b) => Ok(Attribute::Builtin(b)),
192                _ => Err(E::Attribute(
193                    "builtin",
194                    "the argument is not a valid built-in value name",
195                )),
196            },
197            _ => Err(E::Attribute("builtin", "expected 1 argument")),
198        },
199        "const" => match zero_args(args) {
200            true => Ok(Attribute::Const),
201            false => Err(E::Attribute("const", "expected 0 arguments")),
202        },
203        "diagnostic" => match two_args(args) {
204            Some((e1, e2)) => {
205                let severity = ident(e1).and_then(|id| id.name().parse().ok());
206                let rule = match e2.into_inner() {
207                    Expression::TypeOrIdentifier(TypeExpression {
208                        #[cfg(feature = "imports")]
209                            path: _,
210                        ident,
211                        template_args: None,
212                    }) => Some(ident.name().to_string()),
213                    Expression::NamedComponent(e) => {
214                        ident(e.base).map(|id| format!("{}.{}", id.name(), e.component))
215                    }
216                    _ => None,
217                };
218                match (severity, rule) {
219                    (Some(severity), Some(rule)) => {
220                        Ok(Attribute::Diagnostic(DiagnosticAttribute {
221                            severity,
222                            rule,
223                        }))
224                    }
225                    _ => Err(E::Attribute("diagnostic", "invalid arguments")),
226                }
227            }
228            _ => Err(E::Attribute("diagnostic", "expected 1 argument")),
229        },
230        "group" => match one_arg(args) {
231            Some(expr) => Ok(Attribute::Group(expr)),
232            _ => Err(E::Attribute("group", "expected 1 argument")),
233        },
234        "id" => match one_arg(args) {
235            Some(expr) => Ok(Attribute::Id(expr)),
236            _ => Err(E::Attribute("id", "expected 1 argument")),
237        },
238        "interpolate" => match args {
239            Some(v) if v.len() == 2 => {
240                let (e1, e2) = v.into_iter().collect_tuple().unwrap();
241                let ty = ident(e1).and_then(|id| id.name().parse().ok());
242                let sampling = ident(e2).and_then(|id| id.name().parse().ok());
243                match (ty, sampling) {
244                    (Some(ty), Some(sampling)) => {
245                        Ok(Attribute::Interpolate(InterpolateAttribute {
246                            ty,
247                            sampling: Some(sampling),
248                        }))
249                    }
250                    _ => Err(E::Attribute("interpolate", "invalid arguments")),
251                }
252            }
253            Some(v) if v.len() == 1 => {
254                let e1 = v.into_iter().next().unwrap();
255                let ty = ident(e1).and_then(|id| id.name().parse().ok());
256                match ty {
257                    Some(ty) => Ok(Attribute::Interpolate(InterpolateAttribute {
258                        ty,
259                        sampling: None,
260                    })),
261                    _ => Err(E::Attribute("interpolate", "invalid arguments")),
262                }
263            }
264            _ => Err(E::Attribute("interpolate", "invalid arguments")),
265        },
266
267        "invariant" => match zero_args(args) {
268            true => Ok(Attribute::Invariant),
269            false => Err(E::Attribute("invariant", "expected 0 arguments")),
270        },
271        "location" => match one_arg(args) {
272            Some(expr) => Ok(Attribute::Location(expr)),
273            _ => Err(E::Attribute("location", "expected 1 argument")),
274        },
275        "must_use" => match zero_args(args) {
276            true => Ok(Attribute::MustUse),
277            false => Err(E::Attribute("must_use", "expected 0 arguments")),
278        },
279        "size" => match one_arg(args) {
280            Some(expr) => Ok(Attribute::Size(expr)),
281            _ => Err(E::Attribute("size", "expected 1 argument")),
282        },
283        "workgroup_size" => match args {
284            Some(args) => {
285                let mut it = args.into_iter();
286                match (it.next(), it.next(), it.next(), it.next()) {
287                    (Some(x), y, z, None) => {
288                        Ok(Attribute::WorkgroupSize(WorkgroupSizeAttribute { x, y, z }))
289                    }
290                    _ => Err(E::Attribute("workgroup_size", "expected 1-3 arguments")),
291                }
292            }
293            _ => Err(E::Attribute("workgroup_size", "expected 1-3 arguments")),
294        },
295        "vertex" => match zero_args(args) {
296            true => Ok(Attribute::Vertex),
297            false => Err(E::Attribute("vertex", "expected 0 arguments")),
298        },
299        "fragment" => match zero_args(args) {
300            true => Ok(Attribute::Fragment),
301            false => Err(E::Attribute("fragment", "expected 0 arguments")),
302        },
303        "compute" => match zero_args(args) {
304            true => Ok(Attribute::Compute),
305            false => Err(E::Attribute("compute", "expected 0 arguments")),
306        },
307        #[cfg(feature = "condcomp")]
308        "if" => match one_arg(args) {
309            Some(expr) => Ok(Attribute::If(expr)),
310            None => Err(E::Attribute("if", "expected 1 argument")),
311        },
312        #[cfg(feature = "condcomp")]
313        "elif" => match one_arg(args) {
314            Some(expr) => Ok(Attribute::Elif(expr)),
315            None => Err(E::Attribute("elif", "expected 1 argument")),
316        },
317        #[cfg(feature = "condcomp")]
318        "else" => match zero_args(args) {
319            true => Ok(Attribute::Else),
320            false => Err(E::Attribute("else", "expected 0 arguments")),
321        },
322        #[cfg(feature = "generics")]
323        "type" => parse_attr_type(args).map(Attribute::Type),
324        _ => Ok(Attribute::Custom(CustomAttribute {
325            name,
326            arguments: args,
327        })),
328    }
329}
330
// format: @type(T, foo | bar | baz)
/// Parses the arguments of a type-constraint attribute.
///
/// The first argument must be a bare identifier (the constrained type name). The second
/// argument is one or more type expressions joined with `|`; they are returned in source
/// order. Only left-nested chains are accepted: every right-hand operand of a `|` must be
/// a single type expression.
#[cfg(feature = "generics")]
fn parse_attr_type(arguments: Option<Vec<ExpressionNode>>) -> Result<TypeConstraint, E> {
    let (name_expr, constraint_expr) =
        two_args(arguments).ok_or(E::Attribute("type", "expected 2 arguments"))?;
    let name =
        ident(name_expr).ok_or(E::Attribute("type", "invalid first argument (type name)"))?;

    let invalid_constraint =
        || E::Attribute("type", "invalid second argument (type constraint)");

    // Walk down the left spine of `a | b | c`, collecting right operands (last first),
    // then reverse to restore source order.
    let mut variants = Vec::new();
    let mut current = constraint_expr.into_inner();
    loop {
        match current {
            Expression::TypeOrIdentifier(ty) => {
                variants.push(ty);
                break;
            }
            Expression::Binary(BinaryExpression {
                operator: BinaryOperator::BitwiseOr,
                left,
                right,
            }) => {
                match right.into_inner() {
                    Expression::TypeOrIdentifier(ty) => variants.push(ty),
                    _ => return Err(invalid_constraint()),
                }
                current = left.into_inner();
            }
            _ => return Err(invalid_constraint()),
        }
    }
    variants.reverse();

    Ok(TypeConstraint {
        ident: name,
        variants,
    })
}
369
370pub(crate) fn parse_var_template(template_args: TemplateArgs) -> Result<Option<AddressSpace>, E> {
371    match template_args {
372        Some(tplt) => {
373            let mut it = tplt.into_iter();
374            match (it.next(), it.next(), it.next()) {
375                (Some(e1), e2, None) => {
376                    let mut addr_space = ident(e1.expression)
377                        .and_then(|id| id.name().parse().ok())
378                        .ok_or(E::VarTemplate("invalid address space"))?;
379                    if let Some(e2) = e2 {
380                        if let AddressSpace::Storage(access_mode) = &mut addr_space {
381                            *access_mode = Some(
382                                ident(e2.expression)
383                                    .and_then(|id| id.name().parse().ok())
384                                    .ok_or(E::VarTemplate("invalid access mode"))?,
385                            );
386                        } else {
387                            return Err(E::VarTemplate("only variables with `storage` address space can have an access mode"));
388                        }
389                    }
390                    Ok(Some(addr_space))
391                }
392                _ => Err(E::VarTemplate("template is empty")),
393            }
394        }
395        None => Ok(None),
396    }
397}