wgsl_parse/
parser_support.rs

//! Support functions used by the actions of the LALRPOP-generated parser.
2
3use std::str::FromStr;
4
5use itertools::Itertools;
6
7use crate::{
8    error::ParseError,
9    span::{Span, Spanned},
10    syntax::*,
11};
12
13type E = ParseError;
14
15pub(crate) enum Component {
16    Named(Ident),
17    Index(ExpressionNode),
18}
19
20pub(crate) fn apply_components(
21    expr: Expression,
22    span: Span,
23    components: Vec<Spanned<Component>>,
24) -> Expression {
25    components.into_iter().fold(expr, |base, comp| {
26        let span = span.extend(comp.span());
27        let base = Spanned::new(base, span);
28        match comp.into_inner() {
29            Component::Named(component) => {
30                Expression::NamedComponent(NamedComponentExpression { base, component })
31            }
32            Component::Index(index) => Expression::Indexing(IndexingExpression { base, index }),
33        }
34    })
35}
36
37impl FromStr for DeclarationKind {
38    type Err = ();
39
40    fn from_str(s: &str) -> Result<Self, Self::Err> {
41        match s {
42            "const" => Ok(Self::Const),
43            "override" => Ok(Self::Override),
44            "let" => Ok(Self::Let),
45            "var" => Ok(Self::Var(None)),
46            _ => Err(()),
47        }
48    }
49}
50
51impl FromStr for AddressSpace {
52    type Err = ();
53
54    fn from_str(s: &str) -> Result<Self, Self::Err> {
55        match s {
56            "function" => Ok(Self::Function),
57            "private" => Ok(Self::Private),
58            "workgroup" => Ok(Self::Workgroup),
59            "uniform" => Ok(Self::Uniform),
60            "storage" => Ok(Self::Storage(None)),
61            #[cfg(feature = "push_constant")]
62            "push_constant" => Ok(Self::PushConstant),
63            // "WGSL predeclares an enumerant for each address space, except for the handle address space."
64            // "handle" => Ok(Self::Handle),
65            _ => Err(()),
66        }
67    }
68}
69
70impl FromStr for AccessMode {
71    type Err = ();
72
73    fn from_str(s: &str) -> Result<Self, Self::Err> {
74        match s {
75            "read" => Ok(Self::Read),
76            "write" => Ok(Self::Write),
77            "read_write" => Ok(Self::ReadWrite),
78            _ => Err(()),
79        }
80    }
81}
82
83impl FromStr for DiagnosticSeverity {
84    type Err = ();
85
86    fn from_str(s: &str) -> Result<Self, Self::Err> {
87        match s {
88            "error" => Ok(Self::Error),
89            "warning" => Ok(Self::Warning),
90            "info" => Ok(Self::Info),
91            "off" => Ok(Self::Off),
92            _ => Err(()),
93        }
94    }
95}
96
97impl FromStr for BuiltinValue {
98    type Err = ();
99
100    fn from_str(s: &str) -> Result<Self, Self::Err> {
101        match s {
102            "vertex_index" => Ok(Self::VertexIndex),
103            "instance_index" => Ok(Self::InstanceIndex),
104            "clip_distances" => Ok(Self::ClipDistances),
105            "position" => Ok(Self::Position),
106            "front_facing" => Ok(Self::FrontFacing),
107            "frag_depth" => Ok(Self::FragDepth),
108            "sample_index" => Ok(Self::SampleIndex),
109            "sample_mask" => Ok(Self::SampleMask),
110            "local_invocation_id" => Ok(Self::LocalInvocationId),
111            "local_invocation_index" => Ok(Self::LocalInvocationIndex),
112            "global_invocation_id" => Ok(Self::GlobalInvocationId),
113            "workgroup_id" => Ok(Self::WorkgroupId),
114            "num_workgroups" => Ok(Self::NumWorkgroups),
115            "subgroup_invocation_id" => Ok(Self::SubgroupInvocationId),
116            "subgroup_size" => Ok(Self::SubgroupSize),
117            #[cfg(feature = "naga_ext")]
118            "primitive_index" => Ok(Self::PrimitiveIndex),
119            #[cfg(feature = "naga_ext")]
120            "view_index" => Ok(Self::ViewIndex),
121            _ => Err(()),
122        }
123    }
124}
125
126impl FromStr for InterpolationType {
127    type Err = ();
128
129    fn from_str(s: &str) -> Result<Self, Self::Err> {
130        match s {
131            "perspective" => Ok(Self::Perspective),
132            "linear" => Ok(Self::Linear),
133            "flat" => Ok(Self::Flat),
134            _ => Err(()),
135        }
136    }
137}
138
139impl FromStr for InterpolationSampling {
140    type Err = ();
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        match s {
144            "center" => Ok(Self::Center),
145            "centroid" => Ok(Self::Centroid),
146            "sample" => Ok(Self::Sample),
147            "first" => Ok(Self::First),
148            "either" => Ok(Self::Either),
149            _ => Err(()),
150        }
151    }
152}
153
#[cfg(feature = "naga_ext")]
impl FromStr for ConservativeDepth {
    type Err = ();

    /// Recognize the argument of naga's `@early_depth_test(...)` extension attribute.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let depth = match s {
            "greater_equal" => Self::GreaterEqual,
            "less_equal" => Self::LessEqual,
            "unchanged" => Self::Unchanged,
            _ => return Err(()),
        };
        Ok(depth)
    }
}
167
168fn one_arg(arguments: Option<Vec<ExpressionNode>>) -> Option<ExpressionNode> {
169    match arguments {
170        Some(mut args) => (args.len() == 1).then(|| args.pop().unwrap()),
171        None => None,
172    }
173}
174fn two_args(arguments: Option<Vec<ExpressionNode>>) -> Option<(ExpressionNode, ExpressionNode)> {
175    match arguments {
176        Some(args) => (args.len() == 2).then(|| args.into_iter().collect_tuple().unwrap()),
177        None => None,
178    }
179}
180fn zero_args(arguments: Option<Vec<ExpressionNode>>) -> bool {
181    arguments.is_none()
182}
183fn ident(expr: ExpressionNode) -> Option<Ident> {
184    match expr.into_inner() {
185        Expression::TypeOrIdentifier(TypeExpression {
186            #[cfg(feature = "imports")]
187                path: _,
188            ident,
189            template_args: None,
190        }) => Some(ident),
191        _ => None,
192    }
193}
194
195pub(crate) fn parse_attribute(
196    name: String,
197    args: Option<Vec<ExpressionNode>>,
198) -> Result<Attribute, E> {
199    match name.as_str() {
200        "align" => match one_arg(args) {
201            Some(expr) => Ok(Attribute::Align(expr)),
202            _ => Err(E::Attribute("align", "expected 1 argument")),
203        },
204        "binding" => match one_arg(args) {
205            Some(expr) => Ok(Attribute::Binding(expr)),
206            _ => Err(E::Attribute("binding", "expected 1 argument")),
207        },
208        "blend_src" => match one_arg(args) {
209            Some(expr) => Ok(Attribute::BlendSrc(expr)),
210            _ => Err(E::Attribute("blend_src", "expected 1 argument")),
211        },
212        "builtin" => match one_arg(args) {
213            Some(expr) => match ident(expr).and_then(|id| id.name().parse().ok()) {
214                Some(b) => Ok(Attribute::Builtin(b)),
215                _ => Err(E::Attribute(
216                    "builtin",
217                    "the argument is not a valid built-in value name",
218                )),
219            },
220            _ => Err(E::Attribute("builtin", "expected 1 argument")),
221        },
222        "const" => match zero_args(args) {
223            true => Ok(Attribute::Const),
224            false => Err(E::Attribute("const", "expected 0 arguments")),
225        },
226        "diagnostic" => match two_args(args) {
227            Some((e1, e2)) => {
228                let severity = ident(e1).and_then(|id| id.name().parse().ok());
229                let rule = match e2.into_inner() {
230                    Expression::TypeOrIdentifier(TypeExpression {
231                        #[cfg(feature = "imports")]
232                            path: _,
233                        ident,
234                        template_args: None,
235                    }) => Some(ident.name().to_string()),
236                    Expression::NamedComponent(e) => {
237                        ident(e.base).map(|id| format!("{}.{}", id.name(), e.component))
238                    }
239                    _ => None,
240                };
241                match (severity, rule) {
242                    (Some(severity), Some(rule)) => {
243                        Ok(Attribute::Diagnostic(DiagnosticAttribute {
244                            severity,
245                            rule,
246                        }))
247                    }
248                    _ => Err(E::Attribute("diagnostic", "invalid arguments")),
249                }
250            }
251            _ => Err(E::Attribute("diagnostic", "expected 1 argument")),
252        },
253        "group" => match one_arg(args) {
254            Some(expr) => Ok(Attribute::Group(expr)),
255            _ => Err(E::Attribute("group", "expected 1 argument")),
256        },
257        "id" => match one_arg(args) {
258            Some(expr) => Ok(Attribute::Id(expr)),
259            _ => Err(E::Attribute("id", "expected 1 argument")),
260        },
261        "interpolate" => match args {
262            Some(v) if v.len() == 2 => {
263                let (e1, e2) = v.into_iter().collect_tuple().unwrap();
264                let ty = ident(e1).and_then(|id| id.name().parse().ok());
265                let sampling = ident(e2).and_then(|id| id.name().parse().ok());
266                match (ty, sampling) {
267                    (Some(ty), Some(sampling)) => {
268                        Ok(Attribute::Interpolate(InterpolateAttribute {
269                            ty,
270                            sampling: Some(sampling),
271                        }))
272                    }
273                    _ => Err(E::Attribute("interpolate", "invalid arguments")),
274                }
275            }
276            Some(v) if v.len() == 1 => {
277                let e1 = v.into_iter().next().unwrap();
278                let ty = ident(e1).and_then(|id| id.name().parse().ok());
279                match ty {
280                    Some(ty) => Ok(Attribute::Interpolate(InterpolateAttribute {
281                        ty,
282                        sampling: None,
283                    })),
284                    _ => Err(E::Attribute("interpolate", "invalid arguments")),
285                }
286            }
287            _ => Err(E::Attribute("interpolate", "invalid arguments")),
288        },
289
290        "invariant" => match zero_args(args) {
291            true => Ok(Attribute::Invariant),
292            false => Err(E::Attribute("invariant", "expected 0 arguments")),
293        },
294        "location" => match one_arg(args) {
295            Some(expr) => Ok(Attribute::Location(expr)),
296            _ => Err(E::Attribute("location", "expected 1 argument")),
297        },
298        "must_use" => match zero_args(args) {
299            true => Ok(Attribute::MustUse),
300            false => Err(E::Attribute("must_use", "expected 0 arguments")),
301        },
302        "size" => match one_arg(args) {
303            Some(expr) => Ok(Attribute::Size(expr)),
304            _ => Err(E::Attribute("size", "expected 1 argument")),
305        },
306        "workgroup_size" => match args {
307            Some(args) => {
308                let mut it = args.into_iter();
309                match (it.next(), it.next(), it.next(), it.next()) {
310                    (Some(x), y, z, None) => {
311                        Ok(Attribute::WorkgroupSize(WorkgroupSizeAttribute { x, y, z }))
312                    }
313                    _ => Err(E::Attribute("workgroup_size", "expected 1-3 arguments")),
314                }
315            }
316            _ => Err(E::Attribute("workgroup_size", "expected 1-3 arguments")),
317        },
318        "vertex" => match zero_args(args) {
319            true => Ok(Attribute::Vertex),
320            false => Err(E::Attribute("vertex", "expected 0 arguments")),
321        },
322        "fragment" => match zero_args(args) {
323            true => Ok(Attribute::Fragment),
324            false => Err(E::Attribute("fragment", "expected 0 arguments")),
325        },
326        "compute" => match zero_args(args) {
327            true => Ok(Attribute::Compute),
328            false => Err(E::Attribute("compute", "expected 0 arguments")),
329        },
330        #[cfg(feature = "imports")]
331        "publish" => Ok(Attribute::Publish),
332        #[cfg(feature = "condcomp")]
333        "if" => match one_arg(args) {
334            Some(expr) => Ok(Attribute::If(expr)),
335            None => Err(E::Attribute("if", "expected 1 argument")),
336        },
337        #[cfg(feature = "condcomp")]
338        "elif" => match one_arg(args) {
339            Some(expr) => Ok(Attribute::Elif(expr)),
340            None => Err(E::Attribute("elif", "expected 1 argument")),
341        },
342        #[cfg(feature = "condcomp")]
343        "else" => match zero_args(args) {
344            true => Ok(Attribute::Else),
345            false => Err(E::Attribute("else", "expected 0 arguments")),
346        },
347        #[cfg(feature = "generics")]
348        "type" => parse_attr_type(args).map(Attribute::Type),
349        #[cfg(feature = "naga_ext")]
350        "early_depth_test" => match args {
351            Some(args) => {
352                let mut it = args.into_iter();
353                match (it.next(), it.next()) {
354                    (Some(expr), None) => match ident(expr).and_then(|id| id.name().parse().ok()) {
355                        Some(c) => Ok(Attribute::EarlyDepthTest(Some(c))),
356                        _ => Err(E::Attribute(
357                            "early_depth_test",
358                            "the argument must be one of `greater_equal`, `less_equal`, `unchanged`",
359                        )),
360                    },
361                    (None, None) => Ok(Attribute::EarlyDepthTest(None)),
362                    _ => Err(E::Attribute(
363                        "early_depth_test",
364                        "expected 0 or 1 arguments",
365                    )),
366                }
367            }
368            _ => Err(E::Attribute(
369                "early_depth_test",
370                "expected 0 or 1 arguments",
371            )),
372        },
373        _ => Ok(Attribute::Custom(CustomAttribute {
374            name,
375            arguments: args,
376        })),
377    }
378}
379
/// Parses `@type(T, foo | bar | baz)` into a [`TypeConstraint`]. The result holds the
/// type-parameter name `T` and the allowed type expressions, in source order.
///
/// # Errors
///
/// Returns `ParseError::Attribute("type", ..)` in three cases:
/// - there are not exactly two arguments;
/// - the first argument is not an untemplated identifier;
/// - the second argument is not a `|`-separated chain of type expressions.
#[cfg(feature = "generics")]
fn parse_attr_type(arguments: Option<Vec<ExpressionNode>>) -> Result<TypeConstraint, E> {
    /// Walks down the left operands of a `|` chain. Every right operand, and the final
    /// leftmost operand, must be a type expression.
    fn collect_variants(expr: Expression) -> Result<Vec<TypeExpression>, E> {
        let bad_constraint =
            || E::Attribute("type", "invalid second argument (type constraint)");
        let mut variants = Vec::new();
        let mut rest = expr;
        loop {
            match rest {
                Expression::TypeOrIdentifier(ty) => {
                    variants.push(ty);
                    break;
                }
                Expression::Binary(BinaryExpression {
                    operator: BinaryOperator::BitwiseOr,
                    left,
                    right,
                }) => {
                    match right.into_inner() {
                        Expression::TypeOrIdentifier(ty) => variants.push(ty),
                        _ => return Err(bad_constraint()),
                    }
                    rest = left.into_inner();
                }
                _ => return Err(bad_constraint()),
            }
        }
        // Operands were gathered right-to-left; restore source order.
        variants.reverse();
        Ok(variants)
    }

    let (name_expr, constraint_expr) =
        two_args(arguments).ok_or(E::Attribute("type", "expected 2 arguments"))?;
    let type_name =
        ident(name_expr).ok_or(E::Attribute("type", "invalid first argument (type name)"))?;
    let variants = collect_variants(constraint_expr.into_inner())?;
    Ok(TypeConstraint {
        ident: type_name,
        variants,
    })
}
418
419pub(crate) fn parse_var_template(template_args: TemplateArgs) -> Result<Option<AddressSpace>, E> {
420    match template_args {
421        Some(tplt) => {
422            let mut it = tplt.into_iter();
423            match (it.next(), it.next(), it.next()) {
424                (Some(e1), e2, None) => {
425                    let mut addr_space = ident(e1.expression)
426                        .and_then(|id| id.name().parse().ok())
427                        .ok_or(E::VarTemplate("invalid address space"))?;
428                    if let Some(e2) = e2 {
429                        if let AddressSpace::Storage(access_mode) = &mut addr_space {
430                            *access_mode = Some(
431                                ident(e2.expression)
432                                    .and_then(|id| id.name().parse().ok())
433                                    .ok_or(E::VarTemplate("invalid access mode"))?,
434                            );
435                        } else {
436                            return Err(E::VarTemplate(
437                                "only variables with `storage` address space can have an access mode",
438                            ));
439                        }
440                    }
441                    Ok(Some(addr_space))
442                }
443                _ => Err(E::VarTemplate("template is empty")),
444            }
445        }
446        None => Ok(None),
447    }
448}