wgsl_parse/
lexer.rs

1//! Prefer using [`crate::parse_str`]. You shouldn't need to manipulate the lexer.
2
3use crate::error::ParseError;
4use itertools::Itertools;
5use logos::{Logos, SpannedIter};
6use std::{fmt::Display, num::NonZeroU8, sync::LazyLock};
7
8type Span = std::ops::Range<usize>;
9
/// Logos callback for `>`, `>=`, `>>` and `>>=`: decides whether the leading
/// `>` closes a currently-open template list.
///
/// `current` is the token logos matched; `lookahead` is the token that the
/// rest of the match would form if the first `>` is consumed as a template
/// end (e.g. `>>` leaves a pending `>`). When a template is closed, the
/// pending token is stashed in `extras.lookahead` for [`Lexer`] to emit on
/// the next pull.
fn maybe_template_end(
    lex: &mut logos::Lexer<Token>,
    current: Token,
    lookahead: Option<Token>,
) -> Token {
    if let Some(depth) = lex.extras.template_depths.last() {
        // if found a ">" on the same nesting level as the opening "<", it is a template end.
        if lex.extras.depth == *depth {
            lex.extras.template_depths.pop();
            // if lookahead is GreaterThan, we may have a second closing template.
            // note that >>= can never be (TemplateEnd, TemplateEnd, Equal).
            if let Some(depth) = lex.extras.template_depths.last() {
                if lex.extras.depth == *depth && lookahead == Some(Token::SymGreaterThan) {
                    // both `>` of `>>` close a template: emit two TemplateArgsEnd.
                    lex.extras.template_depths.pop();
                    lex.extras.lookahead = Some(Token::TemplateArgsEnd);
                } else {
                    lex.extras.lookahead = lookahead;
                }
            } else {
                lex.extras.lookahead = lookahead;
            }
            return Token::TemplateArgsEnd;
        }
    }

    current
}
37
38// operators && and || have lower precedence than < and >.
39// therefore, this is not a template: a < b || c > d
40fn maybe_fail_template(lex: &mut logos::Lexer<Token>) -> bool {
41    if let Some(depth) = lex.extras.template_depths.last() {
42        if lex.extras.depth == *depth {
43            return false;
44        }
45    }
46    true
47}
48
/// Logos callback: entering a parenthesized/bracketed group — one level deeper.
fn incr_depth(lex: &mut logos::Lexer<Token>) {
    lex.extras.depth += 1;
}
52
/// Logos callback: leaving a parenthesized/bracketed group — one level shallower.
fn decr_depth(lex: &mut logos::Lexer<Token>) {
    lex.extras.depth -= 1;
}
56
// TODO: get rid of crate `lexical`

/// `lexical` number format for decimal literals.
// don't have to be super strict, the lexer regex already did the heavy lifting
const DEC_FORMAT: u128 = lexical::NumberFormatBuilder::new().build();

/// `lexical` number format for hexadecimal literals: `0x`/`0X` prefix,
/// base-16 mantissa, base-16 exponent whose digits are written in decimal.
// don't have to be super strict, the lexer regex already did the heavy lifting
const HEX_FORMAT: u128 = lexical::NumberFormatBuilder::new()
    .mantissa_radix(16)
    .base_prefix(NonZeroU8::new(b'x'))
    .exponent_base(NonZeroU8::new(16))
    .exponent_radix(NonZeroU8::new(10))
    .build();
69
/// Float parse options for hex floats: WGSL hex floats use `p`/`P` as the
/// exponent separator (e.g. `0x1.8p2`) instead of the decimal `e`.
static FLOAT_HEX_OPTIONS: LazyLock<lexical::parse_float_options::Options> = LazyLock::new(|| {
    lexical::parse_float_options::OptionsBuilder::new()
        .exponent(b'p')
        .decimal_point(b'.')
        .build()
        .unwrap()
});
77
78fn parse_dec_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
79    let options = &lexical::parse_integer_options::STANDARD;
80    let str = lex.slice();
81    lexical::parse_with_options::<i64, _, DEC_FORMAT>(str, options).ok()
82}
83
84fn parse_hex_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
85    let options = &lexical::parse_integer_options::STANDARD;
86    let str = lex.slice();
87    lexical::parse_with_options::<i64, _, HEX_FORMAT>(str, options).ok()
88}
89
90fn parse_dec_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
91    let options = &lexical::parse_integer_options::STANDARD;
92    let str = lex.slice();
93    let str = &str[..str.len() - 1];
94    lexical::parse_with_options::<i32, _, DEC_FORMAT>(str, options).ok()
95}
96
97fn parse_hex_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
98    let options = &lexical::parse_integer_options::STANDARD;
99    let str = lex.slice();
100    let str = &str[..str.len() - 1];
101    lexical::parse_with_options::<i32, _, HEX_FORMAT>(str, options).ok()
102}
103
104fn parse_dec_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
105    let options = &lexical::parse_integer_options::STANDARD;
106    let str = lex.slice();
107    let str = &str[..str.len() - 1];
108    lexical::parse_with_options::<u32, _, DEC_FORMAT>(str, options).ok()
109}
110
111fn parse_hex_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
112    let options = &lexical::parse_integer_options::STANDARD;
113    let str = lex.slice();
114    let str = &str[..str.len() - 1];
115    lexical::parse_with_options::<u32, _, HEX_FORMAT>(str, options).ok()
116}
117
118fn parse_dec_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
119    let options = &lexical::parse_float_options::STANDARD;
120    let str = lex.slice();
121    lexical::parse_with_options::<f64, _, DEC_FORMAT>(str, options).ok()
122}
123
124fn parse_hex_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
125    let str = lex.slice();
126    lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
127}
128
129fn parse_dec_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
130    let options = &lexical::parse_float_options::STANDARD;
131    let str = lex.slice();
132    let str = &str[..str.len() - 1];
133    lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
134}
135
136fn parse_hex_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
137    let str = lex.slice();
138    // TODO
139    let options = &lexical::parse_float_options::STANDARD;
140    let str = &str[..str.len() - 1];
141    lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, options).ok()
142}
143
144fn parse_dec_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
145    let options = &lexical::parse_float_options::STANDARD;
146    let str = lex.slice();
147    let str = &str[..str.len() - 1];
148    lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
149}
150
151fn parse_hex_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
152    let str = lex.slice();
153    let str = &str[..str.len() - 1];
154    lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
155}
156
#[cfg(feature = "naga_ext")]
/// Parses a decimal `i64` literal (naga extension), dropping the 2-char `li` suffix.
fn parse_dec_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    let digits = &lex.slice()[..lex.slice().len() - 2];
    lexical::parse_with_options::<i64, _, DEC_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}
164
#[cfg(feature = "naga_ext")]
/// Parses a hexadecimal `i64` literal (naga extension), dropping the 2-char `li` suffix.
fn parse_hex_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    let digits = &lex.slice()[..lex.slice().len() - 2];
    lexical::parse_with_options::<i64, _, HEX_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}
172
#[cfg(feature = "naga_ext")]
/// Parses a decimal `u64` literal (naga extension), dropping the 2-char `lu` suffix.
fn parse_dec_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    let digits = &lex.slice()[..lex.slice().len() - 2];
    lexical::parse_with_options::<u64, _, DEC_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}
180
#[cfg(feature = "naga_ext")]
/// Parses a hexadecimal `u64` literal (naga extension), dropping the 2-char `lu` suffix.
fn parse_hex_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    let digits = &lex.slice()[..lex.slice().len() - 2];
    lexical::parse_with_options::<u64, _, HEX_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}
188
#[cfg(feature = "naga_ext")]
/// Parses a decimal `f64` literal (naga extension), dropping the 2-char `lf` suffix.
fn parse_dec_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let digits = &lex.slice()[..lex.slice().len() - 2];
    lexical::parse_with_options::<f64, _, DEC_FORMAT>(
        digits,
        &lexical::parse_float_options::STANDARD,
    )
    .ok()
}
196
#[cfg(feature = "naga_ext")]
/// Parses a hexadecimal `f64` literal (naga extension), dropping the 2-char
/// `lf` suffix.
///
/// Fix: the previous code used `parse_float_options::STANDARD`, whose
/// exponent separator is `e` — but every hex-float regex for `F64` requires a
/// `p`/`P` exponent, so parsing could never succeed (the `// TODO` left here
/// flagged this). Use [`FLOAT_HEX_OPTIONS`] (`p` exponent), consistent with
/// `parse_hex_abs_float` and `parse_hex_f16`.
fn parse_hex_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let str = lex.slice();
    let str = &str[..str.len() - 2];
    lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
}
205
206fn parse_line_comment(lex: &mut logos::Lexer<Token>) -> logos::Skip {
207    let rem = lex.remainder();
208    // see blankspace and line breaks: https://www.w3.org/TR/WGSL/#blankspace-and-line-breaks
209    let line_end = rem
210        .char_indices()
211        .find(|(_, c)| "\n\u{000B}\u{000C}\r\u{0085}\u{2028}\u{2029}".contains(*c))
212        .map(|(i, _)| i)
213        .unwrap_or(rem.len());
214    lex.bump(line_end);
215    logos::Skip
216}
217
218fn parse_block_comment(lex: &mut logos::Lexer<Token>) -> logos::Skip {
219    let mut depth = 1;
220    while depth > 0 {
221        let rem = lex.remainder();
222        if rem.is_empty() {
223            break;
224        } else if rem.starts_with("/*") {
225            lex.bump(2);
226            depth += 1;
227        } else if rem.starts_with("*/") {
228            lex.bump(2);
229            depth -= 1;
230        } else {
231            lex.bump(1);
232        }
233    }
234    logos::Skip
235}
236
/// These are not valid WGSL identifiers and will trigger an error if used.
/// Exception made of identifiers used by language extensions when the corresponding
/// feature flag is enabled, e.g. `as`, `import`, `super`, `self` for WESL imports.
///
/// Consulted by [`parse_ident`]; the list appears to be kept in ASCII sort
/// order (uppercase before lowercase).
///
/// Reference: https://www.w3.org/TR/WGSL/#reserved-words
const RESERVED_WORDS: &[&str] = &[
    "NULL",
    "Self",
    "abstract",
    "active",
    "alignas",
    "alignof",
    "as",
    "asm",
    "asm_fragment",
    "async",
    "attribute",
    "auto",
    "await",
    "become",
    // `binding_array` is a real type under the naga extensions, so it is only
    // reserved when that feature is off.
    #[cfg(not(feature = "naga_ext"))]
    "binding_array",
    "cast",
    "catch",
    "class",
    "co_await",
    "co_return",
    "co_yield",
    "coherent",
    "column_major",
    "common",
    "compile",
    "compile_fragment",
    "concept",
    "const_cast",
    "consteval",
    "constexpr",
    "constinit",
    "crate",
    "debugger",
    "decltype",
    "delete",
    "demote",
    "demote_to_helper",
    "do",
    "dynamic_cast",
    "enum",
    "explicit",
    "export",
    "extends",
    "extern",
    "external",
    "fallthrough",
    "filter",
    "final",
    "finally",
    "friend",
    "from",
    "fxgroup",
    "get",
    "goto",
    "groupshared",
    "highp",
    "impl",
    "implements",
    "import",
    "inline",
    "instanceof",
    "interface",
    "layout",
    "lowp",
    "macro",
    "macro_rules",
    "match",
    "mediump",
    "meta",
    "mod",
    "module",
    "move",
    "mut",
    "mutable",
    "namespace",
    "new",
    "nil",
    "noexcept",
    "noinline",
    "nointerpolation",
    "non_coherent",
    "noncoherent",
    "noperspective",
    "null",
    "nullptr",
    "of",
    "operator",
    "package",
    "packoffset",
    "partition",
    "pass",
    "patch",
    "pixelfragment",
    "precise",
    "precision",
    "premerge",
    "priv",
    "protected",
    "pub",
    "public",
    "readonly",
    "ref",
    "regardless",
    "register",
    "reinterpret_cast",
    "require",
    "resource",
    "restrict",
    "self",
    "set",
    "shared",
    "sizeof",
    "smooth",
    "snorm",
    "static",
    "static_assert",
    "static_cast",
    "std",
    "subroutine",
    "super",
    "target",
    "template",
    "this",
    "thread_local",
    "throw",
    "trait",
    "try",
    "type",
    "typedef",
    "typeid",
    "typename",
    "typeof",
    "union",
    "unless",
    "unorm",
    "unsafe",
    "unsized",
    "use",
    "using",
    "varying",
    "virtual",
    "volatile",
    "wgsl",
    "where",
    "with",
    "writeonly",
    "yield",
];
392
393fn parse_ident(lex: &mut logos::Lexer<Token>) -> Token {
394    let ident = lex.slice().to_string();
395    if RESERVED_WORDS.iter().contains(&ident.as_str()) {
396        Token::ReservedWord(ident)
397    } else {
398        Token::Ident(ident)
399    }
400}
401
/// Mutable state threaded through the logos lexer via `extras`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct LexerState {
    // current `(` / `[` nesting depth (see incr_depth / decr_depth)
    depth: i32,
    // for each currently-open template list, the `depth` at which its `<` appeared
    template_depths: Vec<i32>,
    // token synthesized by maybe_template_end (e.g. the second `>` of `>>`),
    // emitted by `Lexer::take_two_tokens` before pulling from the stream again
    lookahead: Option<Token>,
}
408
// following the spec at this date: https://www.w3.org/TR/2024/WD-WGSL-20240731/
#[derive(Logos, Clone, Debug, PartialEq)]
#[logos(
    // see blankspace and line breaks: https://www.w3.org/TR/WGSL/#blankspace-and-line-breaks
    skip r"[\s\u0085\u200e\u200f\u2028\u2029]+", // blankspace
    extras = LexerState,
    error = ParseError)]
pub enum Token {
    #[token("//", parse_line_comment)]
    #[token("/*", parse_block_comment, priority = 2)]
    // the parse_ident function can return either Token::Ident or Token::ReservedWord.
    #[regex(
        r#"([_\p{XID_Start}][\p{XID_Continue}]+)|([\p{XID_Start}])"#,
        parse_ident,
        priority = 1
    )]
    // Token::Ignored variant is never produced.
    // It serves as a placeholder for above logos callbacks.
    Ignored,
    // syntactic tokens
    // https://www.w3.org/TR/WGSL/#syntactic-tokens
    #[token("&")]
    SymAnd,
    // `&&` at the candidate template's depth aborts template recognition
    // (lower precedence than `<` / `>`), see maybe_fail_template.
    #[token("&&", maybe_fail_template)]
    SymAndAnd,
    #[token("->")]
    SymArrow,
    #[token("@")]
    SymAttr,
    #[token("/")]
    SymForwardSlash,
    #[token("!")]
    SymBang,
    // brackets and parens track nesting depth for template disambiguation.
    #[token("[", incr_depth)]
    SymBracketLeft,
    #[token("]", decr_depth)]
    SymBracketRight,
    #[token("{")]
    SymBraceLeft,
    #[token("}")]
    SymBraceRight,
    #[token(":")]
    SymColon,
    #[token(",")]
    SymComma,
    #[token("=")]
    SymEqual,
    #[token("==")]
    SymEqualEqual,
    #[token("!=")]
    SymNotEqual,
    // the `>`-family tokens may instead close a template list; the callback
    // decides and stashes any leftover token in extras.lookahead.
    #[token(">", |lex| maybe_template_end(lex, Token::SymGreaterThan, None))]
    SymGreaterThan,
    #[token(">=", |lex| maybe_template_end(lex, Token::SymGreaterThanEqual, Some(Token::SymEqual)))]
    SymGreaterThanEqual,
    #[token(">>", |lex| maybe_template_end(lex, Token::SymShiftRight, Some(Token::SymGreaterThan)))]
    SymShiftRight,
    #[token("<")]
    SymLessThan,
    #[token("<=")]
    SymLessThanEqual,
    #[token("<<")]
    SymShiftLeft,
    #[token("%")]
    SymModulo,
    #[token("-")]
    SymMinus,
    #[token("--")]
    SymMinusMinus,
    #[token(".")]
    SymPeriod,
    #[token("+")]
    SymPlus,
    #[token("++")]
    SymPlusPlus,
    #[token("|")]
    SymOr,
    // `||`, like `&&`, aborts template recognition at the candidate depth.
    #[token("||", maybe_fail_template)]
    SymOrOr,
    #[token("(", incr_depth)]
    SymParenLeft,
    #[token(")", decr_depth)]
    SymParenRight,
    #[token(";")]
    SymSemicolon,
    #[token("*")]
    SymStar,
    #[token("~")]
    SymTilde,
    #[token("_")]
    SymUnderscore,
    #[token("^")]
    SymXor,
    #[token("+=")]
    SymPlusEqual,
    #[token("-=")]
    SymMinusEqual,
    #[token("*=")]
    SymTimesEqual,
    #[token("/=")]
    SymDivisionEqual,
    #[token("%=")]
    SymModuloEqual,
    #[token("&=")]
    SymAndEqual,
    #[token("|=")]
    SymOrEqual,
    #[token("^=")]
    SymXorEqual,
    #[token(">>=", |lex| maybe_template_end(lex, Token::SymShiftRightAssign, Some(Token::SymGreaterThanEqual)))]
    SymShiftRightAssign,
    #[token("<<=")]
    SymShiftLeftAssign,

    // keywords
    // https://www.w3.org/TR/WGSL/#keyword-summary
    #[token("alias")]
    KwAlias,
    #[token("break")]
    KwBreak,
    #[token("case")]
    KwCase,
    #[token("const", priority = 2)]
    KwConst,
    #[token("const_assert")]
    KwConstAssert,
    #[token("continue")]
    KwContinue,
    #[token("continuing")]
    KwContinuing,
    #[token("default")]
    KwDefault,
    #[token("diagnostic")]
    KwDiagnostic,
    #[token("discard")]
    KwDiscard,
    #[token("else")]
    KwElse,
    #[token("enable")]
    KwEnable,
    #[token("false")]
    KwFalse,
    #[token("fn")]
    KwFn,
    #[token("for")]
    KwFor,
    #[token("if")]
    KwIf,
    #[token("let")]
    KwLet,
    #[token("loop")]
    KwLoop,
    #[token("override")]
    KwOverride,
    #[token("requires")]
    KwRequires,
    #[token("return")]
    KwReturn,
    #[token("struct")]
    KwStruct,
    #[token("switch")]
    KwSwitch,
    #[token("true")]
    KwTrue,
    #[token("var")]
    KwVar,
    #[token("while")]
    KwWhile,

    // Idents and ReservedWord tokens are parsed on the Ignored variant, because of a current
    // limitation of logos. See logos#295.
    Ident(String),
    // variant produced by parse_ident for reserved words.
    // Reserved words can be used in context-dependent words, e.g. attribute names and module names.
    ReservedWord(String),

    // literals: the suffix-less forms are "abstract" (type decided later).
    #[regex(r#"0|[1-9]\d*"#, parse_dec_abstract_int)]
    #[regex(r#"0[xX][\da-fA-F]+"#, parse_hex_abstract_int)]
    AbstractInt(i64),
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?"#, parse_dec_abs_float)]
    #[regex(r#"\d+[eE][+-]?\d+"#, parse_dec_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX]\.[\da-fA-F]+([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+"#, parse_hex_abs_float)]
    // hex
    AbstractFloat(f64),
    #[regex(r#"(0|[1-9]\d*)i"#, parse_dec_i32)]
    #[regex(r#"0[xX][\da-fA-F]+i"#, parse_hex_i32)]
    // hex
    I32(i32),
    #[regex(r#"(0|[1-9]\d*)u"#, parse_dec_u32)]
    #[regex(r#"0[xX][\da-fA-F]+u"#, parse_hex_u32)]
    // hex
    U32(u32),
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"\d+([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    F32(f32),
    // f16 values are stored as f32 here.
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"\d+([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    F16(f32),
    // 64-bit literals (`li`, `lu`, `lf` suffixes) are a naga extension.
    #[cfg(feature = "naga_ext")]
    #[regex(r#"(0|[1-9]\d*)li"#, parse_dec_i64)]
    #[regex(r#"0[xX][\da-fA-F]+li"#, parse_hex_i64)]
    // hex
    I64(i64),
    #[cfg(feature = "naga_ext")]
    #[regex(r#"(0|[1-9]\d*)lu"#, parse_dec_u64)]
    #[regex(r#"0[xX][\da-fA-F]+lu"#, parse_hex_u64)]
    // hex
    U64(u64),
    #[cfg(feature = "naga_ext")]
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"\d+([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    F64(f64),
    // synthesized by `Lexer` / `maybe_template_end`, never matched directly.
    TemplateArgsStart,
    TemplateArgsEnd,

    // extension: wesl-imports
    // https://github.com/wgsl-tooling-wg/wesl-spec/blob/imports-update/Imports.md
    // date: 2025-01-18, hash: 2db8e7f681087db6bdcd4a254963deb5c0159775
    #[cfg(feature = "imports")]
    #[token("::")]
    SymColonColon,
    #[cfg(feature = "imports")]
    #[token("self")]
    KwSelf,
    #[cfg(feature = "imports")]
    #[token("super")]
    KwSuper,
    #[cfg(feature = "imports")]
    #[token("package")]
    KwPackage,
    #[cfg(feature = "imports")]
    #[token("as")]
    KwAs,
    #[cfg(feature = "imports")]
    #[token("import")]
    KwImport,
}
657
impl Token {
    /// `true` for operator/punctuation variants (the WGSL "syntactic tokens").
    /// Reference: https://www.w3.org/TR/WGSL/#syntactic-tokens
    #[allow(unused)]
    pub fn is_symbol(&self) -> bool {
        matches!(
            self,
            Token::SymAnd
                | Token::SymAndAnd
                | Token::SymArrow
                | Token::SymAttr
                | Token::SymForwardSlash
                | Token::SymBang
                | Token::SymBracketLeft
                | Token::SymBracketRight
                | Token::SymBraceLeft
                | Token::SymBraceRight
                | Token::SymColon
                | Token::SymComma
                | Token::SymEqual
                | Token::SymEqualEqual
                | Token::SymNotEqual
                | Token::SymGreaterThan
                | Token::SymGreaterThanEqual
                | Token::SymShiftRight
                | Token::SymLessThan
                | Token::SymLessThanEqual
                | Token::SymShiftLeft
                | Token::SymModulo
                | Token::SymMinus
                | Token::SymMinusMinus
                | Token::SymPeriod
                | Token::SymPlus
                | Token::SymPlusPlus
                | Token::SymOr
                | Token::SymOrOr
                | Token::SymParenLeft
                | Token::SymParenRight
                | Token::SymSemicolon
                | Token::SymStar
                | Token::SymTilde
                | Token::SymUnderscore
                | Token::SymXor
                | Token::SymPlusEqual
                | Token::SymMinusEqual
                | Token::SymTimesEqual
                | Token::SymDivisionEqual
                | Token::SymModuloEqual
                | Token::SymAndEqual
                | Token::SymOrEqual
                | Token::SymXorEqual
                | Token::SymShiftRightAssign
                | Token::SymShiftLeftAssign
        )
    }

    /// `true` for WGSL keyword variants (`Kw*`), not identifiers or reserved words.
    #[allow(unused)]
    pub fn is_keyword(&self) -> bool {
        matches!(
            self,
            Token::KwAlias
                | Token::KwBreak
                | Token::KwCase
                | Token::KwConst
                | Token::KwConstAssert
                | Token::KwContinue
                | Token::KwContinuing
                | Token::KwDefault
                | Token::KwDiagnostic
                | Token::KwDiscard
                | Token::KwElse
                | Token::KwEnable
                | Token::KwFalse
                | Token::KwFn
                | Token::KwFor
                | Token::KwIf
                | Token::KwLet
                | Token::KwLoop
                | Token::KwOverride
                | Token::KwRequires
                | Token::KwReturn
                | Token::KwStruct
                | Token::KwSwitch
                | Token::KwTrue
                | Token::KwVar
                | Token::KwWhile
        )
    }

    /// `true` for number-literal variants.
    ///
    /// NOTE(review): the naga_ext variants `I64`/`U64`/`F64` are not listed
    /// here — confirm whether that omission is intentional.
    #[allow(unused)]
    pub fn is_numeric_literal(&self) -> bool {
        matches!(
            self,
            Token::AbstractInt(_)
                | Token::AbstractFloat(_)
                | Token::I32(_)
                | Token::U32(_)
                | Token::F32(_)
                | Token::F16(_)
        )
    }
}
758
impl Display for Token {
    /// This display implementation is used for error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `Ignored` only exists to host logos callbacks; it is never produced.
            Token::Ignored => unreachable!(),
            Token::SymAnd => f.write_str("&"),
            Token::SymAndAnd => f.write_str("&&"),
            Token::SymArrow => f.write_str("->"),
            Token::SymAttr => f.write_str("@"),
            Token::SymForwardSlash => f.write_str("/"),
            Token::SymBang => f.write_str("!"),
            Token::SymBracketLeft => f.write_str("["),
            Token::SymBracketRight => f.write_str("]"),
            Token::SymBraceLeft => f.write_str("{"),
            Token::SymBraceRight => f.write_str("}"),
            Token::SymColon => f.write_str(":"),
            Token::SymComma => f.write_str(","),
            Token::SymEqual => f.write_str("="),
            Token::SymEqualEqual => f.write_str("=="),
            Token::SymNotEqual => f.write_str("!="),
            Token::SymGreaterThan => f.write_str(">"),
            Token::SymGreaterThanEqual => f.write_str(">="),
            Token::SymShiftRight => f.write_str(">>"),
            Token::SymLessThan => f.write_str("<"),
            Token::SymLessThanEqual => f.write_str("<="),
            Token::SymShiftLeft => f.write_str("<<"),
            Token::SymModulo => f.write_str("%"),
            Token::SymMinus => f.write_str("-"),
            Token::SymMinusMinus => f.write_str("--"),
            Token::SymPeriod => f.write_str("."),
            Token::SymPlus => f.write_str("+"),
            Token::SymPlusPlus => f.write_str("++"),
            Token::SymOr => f.write_str("|"),
            Token::SymOrOr => f.write_str("||"),
            Token::SymParenLeft => f.write_str("("),
            Token::SymParenRight => f.write_str(")"),
            Token::SymSemicolon => f.write_str(";"),
            Token::SymStar => f.write_str("*"),
            Token::SymTilde => f.write_str("~"),
            Token::SymUnderscore => f.write_str("_"),
            Token::SymXor => f.write_str("^"),
            Token::SymPlusEqual => f.write_str("+="),
            Token::SymMinusEqual => f.write_str("-="),
            Token::SymTimesEqual => f.write_str("*="),
            Token::SymDivisionEqual => f.write_str("/="),
            Token::SymModuloEqual => f.write_str("%="),
            Token::SymAndEqual => f.write_str("&="),
            Token::SymOrEqual => f.write_str("|="),
            Token::SymXorEqual => f.write_str("^="),
            Token::SymShiftRightAssign => f.write_str(">>="),
            Token::SymShiftLeftAssign => f.write_str("<<="),
            Token::KwAlias => f.write_str("alias"),
            Token::KwBreak => f.write_str("break"),
            Token::KwCase => f.write_str("case"),
            Token::KwConst => f.write_str("const"),
            Token::KwConstAssert => f.write_str("const_assert"),
            Token::KwContinue => f.write_str("continue"),
            Token::KwContinuing => f.write_str("continuing"),
            Token::KwDefault => f.write_str("default"),
            Token::KwDiagnostic => f.write_str("diagnostic"),
            Token::KwDiscard => f.write_str("discard"),
            Token::KwElse => f.write_str("else"),
            Token::KwEnable => f.write_str("enable"),
            Token::KwFalse => f.write_str("false"),
            Token::KwFn => f.write_str("fn"),
            Token::KwFor => f.write_str("for"),
            Token::KwIf => f.write_str("if"),
            Token::KwLet => f.write_str("let"),
            Token::KwLoop => f.write_str("loop"),
            Token::KwOverride => f.write_str("override"),
            Token::KwRequires => f.write_str("requires"),
            Token::KwReturn => f.write_str("return"),
            Token::KwStruct => f.write_str("struct"),
            Token::KwSwitch => f.write_str("switch"),
            Token::KwTrue => f.write_str("true"),
            Token::KwVar => f.write_str("var"),
            Token::KwWhile => f.write_str("while"),
            Token::Ident(s) => write!(f, "identifier `{s}`"),
            Token::ReservedWord(s) => write!(f, "reserved word `{s}`"),
            Token::AbstractInt(n) => write!(f, "{n}"),
            Token::AbstractFloat(n) => write!(f, "{n}"),
            // literal tokens are re-printed with their WGSL type suffix.
            Token::I32(n) => write!(f, "{n}i"),
            Token::U32(n) => write!(f, "{n}u"),
            Token::F32(n) => write!(f, "{n}f"),
            Token::F16(n) => write!(f, "{n}h"),
            #[cfg(feature = "naga_ext")]
            Token::I64(n) => write!(f, "{n}li"),
            #[cfg(feature = "naga_ext")]
            Token::U64(n) => write!(f, "{n}lu"),
            #[cfg(feature = "naga_ext")]
            Token::F64(n) => write!(f, "{n}lf"),
            Token::TemplateArgsStart => f.write_str("start of template"),
            Token::TemplateArgsEnd => f.write_str("end of template"),
            #[cfg(feature = "imports")]
            Token::SymColonColon => write!(f, "::"),
            #[cfg(feature = "imports")]
            Token::KwSelf => write!(f, "self"),
            #[cfg(feature = "imports")]
            Token::KwSuper => write!(f, "super"),
            #[cfg(feature = "imports")]
            Token::KwPackage => write!(f, "package"),
            #[cfg(feature = "imports")]
            Token::KwAs => write!(f, "as"),
            #[cfg(feature = "imports")]
            Token::KwImport => write!(f, "import"),
        }
    }
}
867
/// LALRPOP-style token triple: `Ok((start, token, end))` or `Err((start, error, end))`.
type Spanned<Tok, Loc, ParseError> = Result<(Loc, Tok, Loc), (Loc, ParseError, Loc)>;
/// Raw item produced by the logos `SpannedIter`.
type NextToken = Option<(Result<Token, ParseError>, Span)>;
870
/// WGSL lexer wrapping the logos token stream with one token of lookahead,
/// used to disambiguate template lists (`<` … `>`) from comparison operators.
#[derive(Clone)]
pub struct Lexer<'s> {
    source: &'s str,
    token_stream: SpannedIter<'s, Token>,
    // next token, pre-fetched so `next_token` can inspect one token ahead
    next_token: NextToken,
    // true while this lexer is driven by `recognize_template_list`
    recognizing_template: bool,
    // number of template lists currently open (tracked for recognition mode)
    opened_templates: u32,
}
879
impl<'s> Lexer<'s> {
    /// Creates a lexer over `source` and pre-fetches the first token.
    pub fn new(source: &'s str) -> Self {
        let mut token_stream = Token::lexer_with_extras(source, LexerState::default()).spanned();
        let next_token = token_stream.next();
        Self {
            source,
            token_stream,
            next_token,
            recognizing_template: false,
            opened_templates: 0,
        }
    }

    /// Returns the pre-fetched token plus the one after it. If the lexer
    /// callbacks stashed a lookahead (second half of `>>`, `>=`, etc.), that
    /// token is emitted with a span synthesized inside the current match
    /// instead of advancing the underlying stream.
    fn take_two_tokens(&mut self) -> (NextToken, NextToken) {
        let mut tok1 = self.next_token.take();

        let lookahead = self.token_stream.extras.lookahead.take();
        let tok2 = match lookahead {
            Some(tok) => {
                let (_, span1) = tok1.as_mut().unwrap(); // safety: lookahead implies lexer looked at a `<` token
                // the synthesized token starts one byte into the current match.
                let span2 = span1.start + 1..span1.end;
                Some((Ok(tok), span2))
            }
            None => self.token_stream.next(),
        };

        (tok1, tok2)
    }

    /// Produces the next token, promoting a `<` that follows an identifier or
    /// keyword into `TemplateArgsStart` when template recognition succeeds on
    /// the remaining source.
    fn next_token(&mut self) -> NextToken {
        let (cur, mut next) = self.take_two_tokens();

        let (cur_tok, cur_span) = match cur {
            Some((Ok(tok), span)) => (tok, span),
            Some((Err(e), span)) => return Some((Err(e), span)),
            None => return None,
        };

        if let Some((Ok(next_tok), next_span)) = &mut next {
            if (matches!(cur_tok, Token::Ident(_)) || cur_tok.is_keyword())
                && *next_tok == Token::SymLessThan
            {
                let source = &self.source[next_span.start..];
                if recognize_template_list(source) {
                    *next_tok = Token::TemplateArgsStart;
                    // remember at which nesting depth this template opened.
                    let cur_depth = self.token_stream.extras.depth;
                    self.token_stream.extras.template_depths.push(cur_depth);
                    self.opened_templates += 1;
                }
            }
        }

        // if we finished recognition of a template
        if self.recognizing_template && cur_tok == Token::TemplateArgsEnd {
            self.opened_templates -= 1;
            if self.opened_templates == 0 {
                next = None; // push eof after end of template
            }
        }

        self.next_token = next;
        Some((Ok(cur_tok), cur_span))
    }
}
944
/// Returns `true` if the source starts with a valid template list.
///
/// ## Specification
///
/// [3.9. Template Lists](https://www.w3.org/TR/WGSL/#template-lists-sec)
///
/// Contrary to the specification [template list discovery algorithm], this function also
/// checks that the template is syntactically valid (syntax: [*template_list*]).
///
/// [template list discovery algorithm]: https://www.w3.org/TR/WGSL/#template-list-discovery
/// [*template_list*]: https://www.w3.org/TR/WGSL/#syntax-template_list
pub fn recognize_template_list(source: &str) -> bool {
    // the first token must be `<`; rewrite it as a template start and let the
    // parser's dedicated entry point validate the rest.
    let mut lexer = Lexer::new(source);
    match lexer.next_token {
        Some((Ok(ref mut t), _)) if *t == Token::SymLessThan => *t = Token::TemplateArgsStart,
        _ => return false,
    };
    lexer.recognizing_template = true;
    lexer.opened_templates = 1;
    lexer.token_stream.extras.template_depths.push(0);
    crate::parser::recognize_template_list(lexer).is_ok()
}
967
/// Exercises `recognize_template_list` with the WGSL spec examples plus
/// nesting and failure cases. (Fix: a duplicated `""` assertion was removed.)
#[test]
fn test_recognize_template() {
    // cases from the WGSL spec
    assert!(recognize_template_list("<i32,select(2,3,a>b)>"));
    assert!(!recognize_template_list("<d]>"));
    assert!(recognize_template_list("<B<<C>"));
    assert!(recognize_template_list("<B<=C>"));
    assert!(recognize_template_list("<(B>=C)>"));
    assert!(recognize_template_list("<(B!=C)>"));
    assert!(recognize_template_list("<(B==C)>"));
    // more cases
    assert!(recognize_template_list("<X>"));
    assert!(recognize_template_list("<X<Y>>"));
    assert!(recognize_template_list("<X<Y<Z>>>"));
    assert!(!recognize_template_list(""));
    assert!(!recognize_template_list("<>"));
    assert!(!recognize_template_list("<b || c>d"));
}
987
988pub trait TokenIterator: IntoIterator<Item = Spanned<Token, usize, ParseError>> {}
989
990impl Iterator for Lexer<'_> {
991    type Item = Spanned<Token, usize, ParseError>;
992
993    fn next(&mut self) -> Option<Self::Item> {
994        let tok = self.next_token();
995        tok.map(|(tok, span)| match tok {
996            Ok(tok) => Ok((span.start, tok, span.end)),
997            Err(err) => Err((span.start, err, span.end)),
998        })
999    }
1000}
1001
1002impl TokenIterator for Lexer<'_> {}