1use crate::error::ParseError;
4use itertools::Itertools;
5use logos::{Logos, SpannedIter};
6use std::{fmt::Display, num::NonZeroU8, sync::LazyLock};
7
/// Half-open byte range `start..end` locating a token in the source text.
type Span = core::ops::Range<usize>;
9
10fn maybe_template_end(
11 lex: &mut logos::Lexer<Token>,
12 current: Token,
13 lookahead: Option<Token>,
14) -> Token {
15 if let Some(depth) = lex.extras.template_depths.last() {
16 if lex.extras.depth == *depth {
18 lex.extras.template_depths.pop();
19 if let Some(depth) = lex.extras.template_depths.last() {
22 if lex.extras.depth == *depth && lookahead == Some(Token::SymGreaterThan) {
23 lex.extras.template_depths.pop();
24 lex.extras.lookahead = Some(Token::TemplateArgsEnd);
25 } else {
26 lex.extras.lookahead = lookahead;
27 }
28 } else {
29 lex.extras.lookahead = lookahead;
30 }
31 return Token::TemplateArgsEnd;
32 }
33 }
34
35 current
36}
37
38fn maybe_fail_template(lex: &mut logos::Lexer<Token>) -> bool {
41 if let Some(depth) = lex.extras.template_depths.last() {
42 if lex.extras.depth == *depth {
43 return false;
44 }
45 }
46 true
47}
48
49fn incr_depth(lex: &mut logos::Lexer<Token>) {
50 lex.extras.depth += 1;
51}
52
53fn decr_depth(lex: &mut logos::Lexer<Token>) {
54 lex.extras.depth -= 1;
55}
56
57const DEC_FORMAT: u128 = lexical::NumberFormatBuilder::new().build();
61
62const HEX_FORMAT: u128 = lexical::NumberFormatBuilder::new()
64 .mantissa_radix(16)
65 .base_prefix(NonZeroU8::new(b'x'))
66 .exponent_base(NonZeroU8::new(16))
67 .exponent_radix(NonZeroU8::new(10))
68 .build();
69
70static FLOAT_HEX_OPTIONS: LazyLock<lexical::parse_float_options::Options> = LazyLock::new(|| {
71 lexical::parse_float_options::OptionsBuilder::new()
72 .exponent(b'p')
73 .decimal_point(b'.')
74 .build()
75 .unwrap()
76});
77
78fn parse_dec_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
79 let options = &lexical::parse_integer_options::STANDARD;
80 let str = lex.slice();
81 lexical::parse_with_options::<i64, _, DEC_FORMAT>(str, options).ok()
82}
83
84fn parse_hex_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
85 let options = &lexical::parse_integer_options::STANDARD;
86 let str = lex.slice();
87 lexical::parse_with_options::<i64, _, HEX_FORMAT>(str, options).ok()
88}
89
90fn parse_dec_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
91 let options = &lexical::parse_integer_options::STANDARD;
92 let str = lex.slice();
93 let str = &str[..str.len() - 1];
94 lexical::parse_with_options::<i32, _, DEC_FORMAT>(str, options).ok()
95}
96
97fn parse_hex_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
98 let options = &lexical::parse_integer_options::STANDARD;
99 let str = lex.slice();
100 let str = &str[..str.len() - 1];
101 lexical::parse_with_options::<i32, _, HEX_FORMAT>(str, options).ok()
102}
103
104fn parse_dec_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
105 let options = &lexical::parse_integer_options::STANDARD;
106 let str = lex.slice();
107 let str = &str[..str.len() - 1];
108 lexical::parse_with_options::<u32, _, DEC_FORMAT>(str, options).ok()
109}
110
111fn parse_hex_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
112 let options = &lexical::parse_integer_options::STANDARD;
113 let str = lex.slice();
114 let str = &str[..str.len() - 1];
115 lexical::parse_with_options::<u32, _, HEX_FORMAT>(str, options).ok()
116}
117
118fn parse_dec_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
119 let options = &lexical::parse_float_options::STANDARD;
120 let str = lex.slice();
121 lexical::parse_with_options::<f64, _, DEC_FORMAT>(str, options).ok()
122}
123
124fn parse_hex_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
125 let str = lex.slice();
126 lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
127}
128
129fn parse_dec_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
130 let options = &lexical::parse_float_options::STANDARD;
131 let str = lex.slice();
132 let str = &str[..str.len() - 1];
133 lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
134}
135
136fn parse_hex_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
137 let str = lex.slice();
138 let options = &lexical::parse_float_options::STANDARD;
140 let str = &str[..str.len() - 1];
141 lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, options).ok()
142}
143
144fn parse_dec_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
145 let options = &lexical::parse_float_options::STANDARD;
146 let str = lex.slice();
147 let str = &str[..str.len() - 1];
148 lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
149}
150
151fn parse_hex_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
152 let str = lex.slice();
153 let str = &str[..str.len() - 1];
154 lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
155}
156
/// Callback for decimal `i64` literals such as `42li` (naga extension).
///
/// Drops the two-byte `li` suffix before parsing.
#[cfg(feature = "naga_ext")]
fn parse_dec_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    let text = lex.slice();
    let (digits, _suffix) = text.split_at(text.len() - 2);
    let opts = &lexical::parse_integer_options::STANDARD;
    lexical::parse_with_options::<i64, _, DEC_FORMAT>(digits, opts).ok()
}
164
/// Callback for hexadecimal `i64` literals such as `0x7Fli` (naga extension).
///
/// Drops the two-byte `li` suffix before parsing with `HEX_FORMAT`.
#[cfg(feature = "naga_ext")]
fn parse_hex_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    let text = lex.slice();
    let (digits, _suffix) = text.split_at(text.len() - 2);
    match lexical::parse_with_options::<i64, _, HEX_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    ) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
172
/// Callback for decimal `u64` literals such as `42lu` (naga extension).
///
/// Drops the two-byte `lu` suffix before parsing.
#[cfg(feature = "naga_ext")]
fn parse_dec_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    let text = lex.slice();
    let (digits, _suffix) = text.split_at(text.len() - 2);
    let opts = &lexical::parse_integer_options::STANDARD;
    lexical::parse_with_options::<u64, _, DEC_FORMAT>(digits, opts).ok()
}
180
/// Callback for hexadecimal `u64` literals such as `0xFFlu` (naga extension).
///
/// Drops the two-byte `lu` suffix before parsing with `HEX_FORMAT`.
#[cfg(feature = "naga_ext")]
fn parse_hex_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    let text = lex.slice();
    let (digits, _suffix) = text.split_at(text.len() - 2);
    match lexical::parse_with_options::<u64, _, HEX_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    ) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
188
/// Callback for decimal `f64` literals such as `1.5lf` (naga extension).
///
/// Drops the two-byte `lf` suffix before parsing.
#[cfg(feature = "naga_ext")]
fn parse_dec_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let text = lex.slice();
    let (digits, _suffix) = text.split_at(text.len() - 2);
    let opts = &lexical::parse_float_options::STANDARD;
    lexical::parse_with_options::<f64, _, DEC_FORMAT>(digits, opts).ok()
}
196
/// Callback for hexadecimal `f64` literals such as `0x1.8p3lf` (naga extension).
///
/// Drops the two-byte `lf` suffix and parses the rest with `HEX_FORMAT` and
/// `FLOAT_HEX_OPTIONS`.
///
/// The hex options are required: the standard float options use `e` (a hex digit) as the
/// exponent marker and cannot parse the mandatory `p` exponent. This matches `parse_hex_f16`
/// and `parse_hex_abs_float`.
#[cfg(feature = "naga_ext")]
fn parse_hex_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let str = lex.slice();
    let str = &str[..str.len() - 2];
    lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
}
205
206fn parse_line_comment(lex: &mut logos::Lexer<Token>) -> logos::Skip {
207 let rem = lex.remainder();
208 let line_end = rem
210 .char_indices()
211 .find(|(_, c)| "\n\u{000B}\u{000C}\r\u{0085}\u{2028}\u{2029}".contains(*c))
212 .map(|(i, _)| i)
213 .unwrap_or(rem.len());
214 lex.bump(line_end);
215 logos::Skip
216}
217
218fn parse_block_comment(lex: &mut logos::Lexer<Token>) -> logos::Skip {
219 let mut depth = 1;
220 while depth > 0 {
221 let rem = lex.remainder();
222 if rem.is_empty() {
223 break;
224 } else if rem.starts_with("/*") {
225 lex.bump(2);
226 depth += 1;
227 } else if rem.starts_with("*/") {
228 lex.bump(2);
229 depth -= 1;
230 } else {
231 lex.bump(1);
232 }
233 }
234 logos::Skip
235}
236
/// Words reserved by WGSL for future use.
///
/// The identifier regex on `Token::Ignored` matches these like any identifier. `parse_ident`
/// then reports them as `Token::ReservedWord` instead of `Token::Ident`. Entries are listed in
/// ASCII order.
///
/// `binding_array` is only reserved when the `naga_ext` feature is disabled. Some entries
/// (`as`, `import`, `package`, `self`, `super`) also have dedicated `#[token]`s under the
/// `imports` feature; presumably those token patterns take precedence over the identifier
/// regex (declared with `priority = 1`) — confirm if changing priorities.
const RESERVED_WORDS: &[&str] = &[
    "NULL",
    "Self",
    "abstract",
    "active",
    "alignas",
    "alignof",
    "as",
    "asm",
    "asm_fragment",
    "async",
    "attribute",
    "auto",
    "await",
    "become",
    #[cfg(not(feature = "naga_ext"))]
    "binding_array",
    "cast",
    "catch",
    "class",
    "co_await",
    "co_return",
    "co_yield",
    "coherent",
    "column_major",
    "common",
    "compile",
    "compile_fragment",
    "concept",
    "const_cast",
    "consteval",
    "constexpr",
    "constinit",
    "crate",
    "debugger",
    "decltype",
    "delete",
    "demote",
    "demote_to_helper",
    "do",
    "dynamic_cast",
    "enum",
    "explicit",
    "export",
    "extends",
    "extern",
    "external",
    "fallthrough",
    "filter",
    "final",
    "finally",
    "friend",
    "from",
    "fxgroup",
    "get",
    "goto",
    "groupshared",
    "highp",
    "impl",
    "implements",
    "import",
    "inline",
    "instanceof",
    "interface",
    "layout",
    "lowp",
    "macro",
    "macro_rules",
    "match",
    "mediump",
    "meta",
    "mod",
    "module",
    "move",
    "mut",
    "mutable",
    "namespace",
    "new",
    "nil",
    "noexcept",
    "noinline",
    "nointerpolation",
    "non_coherent",
    "noncoherent",
    "noperspective",
    "null",
    "nullptr",
    "of",
    "operator",
    "package",
    "packoffset",
    "partition",
    "pass",
    "patch",
    "pixelfragment",
    "precise",
    "precision",
    "premerge",
    "priv",
    "protected",
    "pub",
    "public",
    "readonly",
    "ref",
    "regardless",
    "register",
    "reinterpret_cast",
    "require",
    "resource",
    "restrict",
    "self",
    "set",
    "shared",
    "sizeof",
    "smooth",
    "snorm",
    "static",
    "static_assert",
    "static_cast",
    "std",
    "subroutine",
    "super",
    "target",
    "template",
    "this",
    "thread_local",
    "throw",
    "trait",
    "try",
    "type",
    "typedef",
    "typeid",
    "typename",
    "typeof",
    "union",
    "unless",
    "unorm",
    "unsafe",
    "unsized",
    "use",
    "using",
    "varying",
    "virtual",
    "volatile",
    "wgsl",
    "where",
    "with",
    "writeonly",
    "yield",
];
392
393fn parse_ident(lex: &mut logos::Lexer<Token>) -> Token {
394 let ident = lex.slice().to_string();
395 if RESERVED_WORDS.iter().contains(&ident.as_str()) {
396 Token::ReservedWord(ident)
397 } else {
398 Token::Ident(ident)
399 }
400}
401
402#[derive(Default, Clone, Debug, PartialEq)]
403pub struct LexerState {
404 depth: i32,
405 template_depths: Vec<i32>,
406 lookahead: Option<Token>,
407}
408
/// A WGSL token, produced by the [`Logos`] derive.
///
/// Whitespace matching the `skip` pattern is discarded. `TemplateArgsStart` and
/// `TemplateArgsEnd` have no pattern of their own: [`Lexer`] substitutes `TemplateArgsStart`
/// for `<`, and `maybe_template_end` produces `TemplateArgsEnd` from the `>`-family tokens.
#[derive(Logos, Clone, Debug, PartialEq)]
#[logos(
    skip r"[\s\u0085\u200e\u200f\u2028\u2029]+", extras = LexerState,
    error = ParseError)]
pub enum Token {
    /// Carrier for the comment and identifier patterns; not itself emitted by their callbacks.
    ///
    /// The comment callbacks return `logos::Skip`. The identifier callback `parse_ident` returns
    /// `Ident` or `ReservedWord` instead of this variant.
    #[token("//", parse_line_comment)]
    #[token("/*", parse_block_comment, priority = 2)]
    #[regex(
        r#"([_\p{XID_Start}][\p{XID_Continue}]+)|([\p{XID_Start}])"#,
        parse_ident,
        priority = 1
    )]
    Ignored,
    #[token("&")]
    SymAnd,
    /// `&&`; `maybe_fail_template` rejects it directly inside an open template list.
    #[token("&&", maybe_fail_template)]
    SymAndAnd,
    #[token("->")]
    SymArrow,
    #[token("@")]
    SymAttr,
    #[token("/")]
    SymForwardSlash,
    #[token("!")]
    SymBang,
    /// `[`; increases the bracket depth used for template-list matching.
    #[token("[", incr_depth)]
    SymBracketLeft,
    /// `]`; decreases the bracket depth.
    #[token("]", decr_depth)]
    SymBracketRight,
    #[token("{")]
    SymBraceLeft,
    #[token("}")]
    SymBraceRight,
    #[token(":")]
    SymColon,
    #[token(",")]
    SymComma,
    #[token("=")]
    SymEqual,
    #[token("==")]
    SymEqualEqual,
    #[token("!=")]
    SymNotEqual,
    /// `>`; becomes `TemplateArgsEnd` when it closes an open template list.
    #[token(">", |lex| maybe_template_end(lex, Token::SymGreaterThan, None))]
    SymGreaterThan,
    /// `>=`; may be split into `TemplateArgsEnd` followed by `=`.
    #[token(">=", |lex| maybe_template_end(lex, Token::SymGreaterThanEqual, Some(Token::SymEqual)))]
    SymGreaterThanEqual,
    /// `>>`; may be split into `TemplateArgsEnd` followed by `>` or a second `TemplateArgsEnd`.
    #[token(">>", |lex| maybe_template_end(lex, Token::SymShiftRight, Some(Token::SymGreaterThan)))]
    SymShiftRight,
    /// `<`; [`Lexer`] rewrites it to `TemplateArgsStart` when it opens a template list.
    #[token("<")]
    SymLessThan,
    #[token("<=")]
    SymLessThanEqual,
    #[token("<<")]
    SymShiftLeft,
    #[token("%")]
    SymModulo,
    #[token("-")]
    SymMinus,
    #[token("--")]
    SymMinusMinus,
    #[token(".")]
    SymPeriod,
    #[token("+")]
    SymPlus,
    #[token("++")]
    SymPlusPlus,
    #[token("|")]
    SymOr,
    /// `||`; `maybe_fail_template` rejects it directly inside an open template list.
    #[token("||", maybe_fail_template)]
    SymOrOr,
    /// `(`; increases the bracket depth used for template-list matching.
    #[token("(", incr_depth)]
    SymParenLeft,
    /// `)`; decreases the bracket depth.
    #[token(")", decr_depth)]
    SymParenRight,
    #[token(";")]
    SymSemicolon,
    #[token("*")]
    SymStar,
    #[token("~")]
    SymTilde,
    /// A lone `_` (longer `_`-prefixed words match the identifier regex instead).
    #[token("_")]
    SymUnderscore,
    #[token("^")]
    SymXor,
    #[token("+=")]
    SymPlusEqual,
    #[token("-=")]
    SymMinusEqual,
    #[token("*=")]
    SymTimesEqual,
    #[token("/=")]
    SymDivisionEqual,
    #[token("%=")]
    SymModuloEqual,
    #[token("&=")]
    SymAndEqual,
    #[token("|=")]
    SymOrEqual,
    #[token("^=")]
    SymXorEqual,
    /// `>>=`; may be split into `TemplateArgsEnd` followed by `>=`.
    #[token(">>=", |lex| maybe_template_end(lex, Token::SymShiftRightAssign, Some(Token::SymGreaterThanEqual)))]
    SymShiftRightAssign,
    #[token("<<=")]
    SymShiftLeftAssign,

    // Keywords.
    #[token("alias")]
    KwAlias,
    #[token("break")]
    KwBreak,
    #[token("case")]
    KwCase,
    #[token("const", priority = 2)]
    KwConst,
    #[token("const_assert")]
    KwConstAssert,
    #[token("continue")]
    KwContinue,
    #[token("continuing")]
    KwContinuing,
    #[token("default")]
    KwDefault,
    #[token("diagnostic")]
    KwDiagnostic,
    #[token("discard")]
    KwDiscard,
    #[token("else")]
    KwElse,
    #[token("enable")]
    KwEnable,
    #[token("false")]
    KwFalse,
    #[token("fn")]
    KwFn,
    #[token("for")]
    KwFor,
    #[token("if")]
    KwIf,
    #[token("let")]
    KwLet,
    #[token("loop")]
    KwLoop,
    #[token("override")]
    KwOverride,
    #[token("requires")]
    KwRequires,
    #[token("return")]
    KwReturn,
    #[token("struct")]
    KwStruct,
    #[token("switch")]
    KwSwitch,
    #[token("true")]
    KwTrue,
    #[token("var")]
    KwVar,
    #[token("while")]
    KwWhile,

    /// An identifier; produced by `parse_ident` via the regex on `Ignored`.
    Ident(String),
    /// A word from `RESERVED_WORDS`; produced by `parse_ident` via the regex on `Ignored`.
    ReservedWord(String),

    /// Unsuffixed integer literal, decimal or `0x` hexadecimal.
    #[regex(r#"0|[1-9]\d*"#, parse_dec_abstract_int)]
    #[regex(r#"0[xX][\da-fA-F]+"#, parse_hex_abstract_int)]
    AbstractInt(i64),
    /// Unsuffixed float literal, decimal or hexadecimal.
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?"#, parse_dec_abs_float)]
    #[regex(r#"\d+[eE][+-]?\d+"#, parse_dec_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX]\.[\da-fA-F]+([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+"#, parse_hex_abs_float)]
    AbstractFloat(f64),
    /// `i`-suffixed integer literal.
    #[regex(r#"(0|[1-9]\d*)i"#, parse_dec_i32)]
    #[regex(r#"0[xX][\da-fA-F]+i"#, parse_hex_i32)]
    I32(i32),
    /// `u`-suffixed integer literal.
    #[regex(r#"(0|[1-9]\d*)u"#, parse_dec_u32)]
    #[regex(r#"0[xX][\da-fA-F]+u"#, parse_hex_u32)]
    U32(u32),
    /// `f`-suffixed float literal (hex forms require a `p` exponent before the suffix).
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"\d+([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    F32(f32),
    /// `h`-suffixed float literal, stored as `f32`.
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"\d+([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    F16(f32),
    /// `li`-suffixed integer literal (naga extension).
    #[cfg(feature = "naga_ext")]
    #[regex(r#"(0|[1-9]\d*)li"#, parse_dec_i64)]
    #[regex(r#"0[xX][\da-fA-F]+li"#, parse_hex_i64)]
    I64(i64),
    /// `lu`-suffixed integer literal (naga extension).
    #[cfg(feature = "naga_ext")]
    #[regex(r#"(0|[1-9]\d*)lu"#, parse_dec_u64)]
    #[regex(r#"0[xX][\da-fA-F]+lu"#, parse_hex_u64)]
    U64(u64),
    /// `lf`-suffixed float literal (naga extension).
    #[cfg(feature = "naga_ext")]
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"\d+([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    F64(f64),
    /// A `<` recognized as opening a template list.
    TemplateArgsStart,
    /// A `>` recognized as closing a template list.
    TemplateArgsEnd,

    // Tokens of the `imports` extension.
    #[cfg(feature = "imports")]
    #[token("::")]
    SymColonColon,
    #[cfg(feature = "imports")]
    #[token("self")]
    KwSelf,
    #[cfg(feature = "imports")]
    #[token("super")]
    KwSuper,
    #[cfg(feature = "imports")]
    #[token("package")]
    KwPackage,
    #[cfg(feature = "imports")]
    #[token("as")]
    KwAs,
    #[cfg(feature = "imports")]
    #[token("import")]
    KwImport,
}
657
658impl Token {
659 #[allow(unused)]
660 pub fn is_symbol(&self) -> bool {
661 matches!(
662 self,
663 Token::SymAnd
664 | Token::SymAndAnd
665 | Token::SymArrow
666 | Token::SymAttr
667 | Token::SymForwardSlash
668 | Token::SymBang
669 | Token::SymBracketLeft
670 | Token::SymBracketRight
671 | Token::SymBraceLeft
672 | Token::SymBraceRight
673 | Token::SymColon
674 | Token::SymComma
675 | Token::SymEqual
676 | Token::SymEqualEqual
677 | Token::SymNotEqual
678 | Token::SymGreaterThan
679 | Token::SymGreaterThanEqual
680 | Token::SymShiftRight
681 | Token::SymLessThan
682 | Token::SymLessThanEqual
683 | Token::SymShiftLeft
684 | Token::SymModulo
685 | Token::SymMinus
686 | Token::SymMinusMinus
687 | Token::SymPeriod
688 | Token::SymPlus
689 | Token::SymPlusPlus
690 | Token::SymOr
691 | Token::SymOrOr
692 | Token::SymParenLeft
693 | Token::SymParenRight
694 | Token::SymSemicolon
695 | Token::SymStar
696 | Token::SymTilde
697 | Token::SymUnderscore
698 | Token::SymXor
699 | Token::SymPlusEqual
700 | Token::SymMinusEqual
701 | Token::SymTimesEqual
702 | Token::SymDivisionEqual
703 | Token::SymModuloEqual
704 | Token::SymAndEqual
705 | Token::SymOrEqual
706 | Token::SymXorEqual
707 | Token::SymShiftRightAssign
708 | Token::SymShiftLeftAssign
709 )
710 }
711
712 #[allow(unused)]
713 pub fn is_keyword(&self) -> bool {
714 matches!(
715 self,
716 Token::KwAlias
717 | Token::KwBreak
718 | Token::KwCase
719 | Token::KwConst
720 | Token::KwConstAssert
721 | Token::KwContinue
722 | Token::KwContinuing
723 | Token::KwDefault
724 | Token::KwDiagnostic
725 | Token::KwDiscard
726 | Token::KwElse
727 | Token::KwEnable
728 | Token::KwFalse
729 | Token::KwFn
730 | Token::KwFor
731 | Token::KwIf
732 | Token::KwLet
733 | Token::KwLoop
734 | Token::KwOverride
735 | Token::KwRequires
736 | Token::KwReturn
737 | Token::KwStruct
738 | Token::KwSwitch
739 | Token::KwTrue
740 | Token::KwVar
741 | Token::KwWhile
742 )
743 }
744
745 #[allow(unused)]
746 pub fn is_numeric_literal(&self) -> bool {
747 matches!(
748 self,
749 Token::AbstractInt(_)
750 | Token::AbstractFloat(_)
751 | Token::I32(_)
752 | Token::U32(_)
753 | Token::F32(_)
754 | Token::F16(_)
755 )
756 }
757}
758
759impl Display for Token {
760 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
762 match self {
763 Token::Ignored => unreachable!(),
764 Token::SymAnd => f.write_str("&"),
765 Token::SymAndAnd => f.write_str("&&"),
766 Token::SymArrow => f.write_str("->"),
767 Token::SymAttr => f.write_str("@"),
768 Token::SymForwardSlash => f.write_str("/"),
769 Token::SymBang => f.write_str("!"),
770 Token::SymBracketLeft => f.write_str("["),
771 Token::SymBracketRight => f.write_str("]"),
772 Token::SymBraceLeft => f.write_str("{"),
773 Token::SymBraceRight => f.write_str("}"),
774 Token::SymColon => f.write_str(":"),
775 Token::SymComma => f.write_str(","),
776 Token::SymEqual => f.write_str("="),
777 Token::SymEqualEqual => f.write_str("=="),
778 Token::SymNotEqual => f.write_str("!="),
779 Token::SymGreaterThan => f.write_str(">"),
780 Token::SymGreaterThanEqual => f.write_str(">="),
781 Token::SymShiftRight => f.write_str(">>"),
782 Token::SymLessThan => f.write_str("<"),
783 Token::SymLessThanEqual => f.write_str("<="),
784 Token::SymShiftLeft => f.write_str("<<"),
785 Token::SymModulo => f.write_str("%"),
786 Token::SymMinus => f.write_str("-"),
787 Token::SymMinusMinus => f.write_str("--"),
788 Token::SymPeriod => f.write_str("."),
789 Token::SymPlus => f.write_str("+"),
790 Token::SymPlusPlus => f.write_str("++"),
791 Token::SymOr => f.write_str("|"),
792 Token::SymOrOr => f.write_str("||"),
793 Token::SymParenLeft => f.write_str("("),
794 Token::SymParenRight => f.write_str(")"),
795 Token::SymSemicolon => f.write_str(";"),
796 Token::SymStar => f.write_str("*"),
797 Token::SymTilde => f.write_str("~"),
798 Token::SymUnderscore => f.write_str("_"),
799 Token::SymXor => f.write_str("^"),
800 Token::SymPlusEqual => f.write_str("+="),
801 Token::SymMinusEqual => f.write_str("-="),
802 Token::SymTimesEqual => f.write_str("*="),
803 Token::SymDivisionEqual => f.write_str("/="),
804 Token::SymModuloEqual => f.write_str("%="),
805 Token::SymAndEqual => f.write_str("&="),
806 Token::SymOrEqual => f.write_str("|="),
807 Token::SymXorEqual => f.write_str("^="),
808 Token::SymShiftRightAssign => f.write_str(">>="),
809 Token::SymShiftLeftAssign => f.write_str("<<="),
810 Token::KwAlias => f.write_str("alias"),
811 Token::KwBreak => f.write_str("break"),
812 Token::KwCase => f.write_str("case"),
813 Token::KwConst => f.write_str("const"),
814 Token::KwConstAssert => f.write_str("const_assert"),
815 Token::KwContinue => f.write_str("continue"),
816 Token::KwContinuing => f.write_str("continuing"),
817 Token::KwDefault => f.write_str("default"),
818 Token::KwDiagnostic => f.write_str("diagnostic"),
819 Token::KwDiscard => f.write_str("discard"),
820 Token::KwElse => f.write_str("else"),
821 Token::KwEnable => f.write_str("enable"),
822 Token::KwFalse => f.write_str("false"),
823 Token::KwFn => f.write_str("fn"),
824 Token::KwFor => f.write_str("for"),
825 Token::KwIf => f.write_str("if"),
826 Token::KwLet => f.write_str("let"),
827 Token::KwLoop => f.write_str("loop"),
828 Token::KwOverride => f.write_str("override"),
829 Token::KwRequires => f.write_str("requires"),
830 Token::KwReturn => f.write_str("return"),
831 Token::KwStruct => f.write_str("struct"),
832 Token::KwSwitch => f.write_str("switch"),
833 Token::KwTrue => f.write_str("true"),
834 Token::KwVar => f.write_str("var"),
835 Token::KwWhile => f.write_str("while"),
836 Token::Ident(s) => write!(f, "identifier `{s}`"),
837 Token::ReservedWord(s) => write!(f, "reserved word `{s}`"),
838 Token::AbstractInt(n) => write!(f, "{n}"),
839 Token::AbstractFloat(n) => write!(f, "{n}"),
840 Token::I32(n) => write!(f, "{n}i"),
841 Token::U32(n) => write!(f, "{n}u"),
842 Token::F32(n) => write!(f, "{n}f"),
843 Token::F16(n) => write!(f, "{n}h"),
844 #[cfg(feature = "naga_ext")]
845 Token::I64(n) => write!(f, "{n}li"),
846 #[cfg(feature = "naga_ext")]
847 Token::U64(n) => write!(f, "{n}lu"),
848 #[cfg(feature = "naga_ext")]
849 Token::F64(n) => write!(f, "{n}lf"),
850 Token::TemplateArgsStart => f.write_str("start of template"),
851 Token::TemplateArgsEnd => f.write_str("end of template"),
852 #[cfg(feature = "imports")]
853 Token::SymColonColon => write!(f, "::"),
854 #[cfg(feature = "imports")]
855 Token::KwSelf => write!(f, "self"),
856 #[cfg(feature = "imports")]
857 Token::KwSuper => write!(f, "super"),
858 #[cfg(feature = "imports")]
859 Token::KwPackage => write!(f, "package"),
860 #[cfg(feature = "imports")]
861 Token::KwAs => write!(f, "as"),
862 #[cfg(feature = "imports")]
863 Token::KwImport => write!(f, "import"),
864 }
865 }
866}
867
/// Item produced by the lexer's iterator.
///
/// `(start, token, end)` on success and `(start, error, end)` on failure, with `start` and `end`
/// of location type `Loc`. The error parameter is named `Error` so it does not shadow the
/// imported `ParseError` type.
type Spanned<Tok, Loc, Error> = Result<(Loc, Tok, Loc), (Loc, Error, Loc)>;
869type NextToken = Option<(Result<Token, ParseError>, Span)>;
870
871#[derive(Clone)]
872pub struct Lexer<'s> {
873 source: &'s str,
874 token_stream: SpannedIter<'s, Token>,
875 next_token: NextToken,
876 recognizing_template: bool,
877 opened_templates: u32,
878}
879
880impl<'s> Lexer<'s> {
881 pub fn new(source: &'s str) -> Self {
882 let mut token_stream = Token::lexer_with_extras(source, LexerState::default()).spanned();
883 let next_token = token_stream.next();
884 Self {
885 source,
886 token_stream,
887 next_token,
888 recognizing_template: false,
889 opened_templates: 0,
890 }
891 }
892
893 fn take_two_tokens(&mut self) -> (NextToken, NextToken) {
894 let mut tok1 = self.next_token.take();
895
896 let lookahead = self.token_stream.extras.lookahead.take();
897 let tok2 = match lookahead {
898 Some(tok) => {
899 let (_, span1) = tok1.as_mut().unwrap(); let span2 = span1.start + 1..span1.end;
901 Some((Ok(tok), span2))
902 }
903 None => self.token_stream.next(),
904 };
905
906 (tok1, tok2)
907 }
908
909 fn next_token(&mut self) -> NextToken {
910 let (cur, mut next) = self.take_two_tokens();
911
912 let (cur_tok, cur_span) = match cur {
913 Some((Ok(tok), span)) => (tok, span),
914 Some((Err(e), span)) => return Some((Err(e), span)),
915 None => return None,
916 };
917
918 if let Some((Ok(next_tok), next_span)) = &mut next {
919 if (matches!(cur_tok, Token::Ident(_)) || cur_tok.is_keyword())
920 && *next_tok == Token::SymLessThan
921 {
922 let source = &self.source[next_span.start..];
923 if recognize_template_list(source) {
924 *next_tok = Token::TemplateArgsStart;
925 let cur_depth = self.token_stream.extras.depth;
926 self.token_stream.extras.template_depths.push(cur_depth);
927 self.opened_templates += 1;
928 }
929 }
930 }
931
932 if self.recognizing_template && cur_tok == Token::TemplateArgsEnd {
934 self.opened_templates -= 1;
935 if self.opened_templates == 0 {
936 next = None; }
938 }
939
940 self.next_token = next;
941 Some((Ok(cur_tok), cur_span))
942 }
943}
944
945pub fn recognize_template_list(source: &str) -> bool {
957 let mut lexer = Lexer::new(source);
958 match lexer.next_token {
959 Some((Ok(ref mut t), _)) if *t == Token::SymLessThan => *t = Token::TemplateArgsStart,
960 _ => return false,
961 };
962 lexer.recognizing_template = true;
963 lexer.opened_templates = 1;
964 lexer.token_stream.extras.template_depths.push(0);
965 crate::parser::recognize_template_list(lexer).is_ok()
966}
967
/// Template-list discovery on positive and negative inputs.
#[test]
fn test_recognize_template() {
    // Comparisons nested inside parentheses do not end the template.
    assert!(recognize_template_list("<i32,select(2,3,a>b)>"));
    assert!(!recognize_template_list("<d]>"));
    assert!(recognize_template_list("<B<<C>"));
    assert!(recognize_template_list("<B<=C>"));
    assert!(recognize_template_list("<(B>=C)>"));
    assert!(recognize_template_list("<(B!=C)>"));
    assert!(recognize_template_list("<(B==C)>"));
    assert!(recognize_template_list("<X>"));
    // `>>` closing two nested template lists at once.
    assert!(recognize_template_list("<X<Y>>"));
    assert!(recognize_template_list("<X<Y<Z>>>"));
    assert!(!recognize_template_list(""));
    // Input that does not begin with `<` is never a template list.
    assert!(!recognize_template_list("a<b>"));
    assert!(!recognize_template_list("<>"));
    // `||` at template depth rejects the template.
    assert!(!recognize_template_list("<b || c>d"));
}
987
988pub trait TokenIterator: IntoIterator<Item = Spanned<Token, usize, ParseError>> {}
989
990impl Iterator for Lexer<'_> {
991 type Item = Spanned<Token, usize, ParseError>;
992
993 fn next(&mut self) -> Option<Self::Item> {
994 let tok = self.next_token();
995 tok.map(|(tok, span)| match tok {
996 Ok(tok) => Ok((span.start, tok, span.end)),
997 Err(err) => Err((span.start, err, span.end)),
998 })
999 }
1000}
1001
1002impl TokenIterator for Lexer<'_> {}