Skip to main content

squawk_parser/
lexed_str.rs

1// based on https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/parser/src/lexed_str.rs
2
3use std::{num::IntErrorKind, ops};
4
5use squawk_lexer::tokenize;
6
7use crate::SyntaxKind;
8
/// The token stream produced by lexing a source string.
///
/// Tokens are stored column-wise: `kind[i]` is the kind of token `i` and
/// `start[i]` is its byte offset into `text`. A trailing `EOF` entry is
/// appended when construction finishes, so `kind` and `start` hold one more
/// element than the number of real tokens. This lets the end of token `i` be
/// read as `start[i + 1]`.
pub struct LexedStr<'a> {
    // The source text the offsets in `start` index into.
    text: &'a str,
    // Kind of each token, followed by the `EOF` sentinel.
    kind: Vec<SyntaxKind>,
    // Starting byte offset of each token, followed by the offset of `EOF`.
    start: Vec<u32>,
    // Diagnostics recorded while lexing, in the order they were found.
    error: Vec<LexError>,
}
15
/// A single lexing diagnostic.
struct LexError {
    // Human-readable description of the problem.
    msg: String,
    // Absolute byte range in the source text that the diagnostic points at.
    range: ops::Range<u32>,
}
20
21impl<'a> LexedStr<'a> {
22    // TODO: rust-analyzer has an edition thing to specify things that are only
23    // available in certain version, we can do that later
24    pub fn new(text: &'a str) -> LexedStr<'a> {
25        let mut conv = Converter::new(text);
26
27        for token in tokenize(&text[conv.offset..]) {
28            let token_text = &text[conv.offset..][..token.len as usize];
29
30            conv.extend_token(&token.kind, token_text);
31        }
32
33        conv.finalize_with_eof()
34    }
35
36    // pub(crate) fn single_token(text: &'a str) -> Option<(SyntaxKind, Option<String>)> {
37    //     if text.is_empty() {
38    //         return None;
39    //     }
40
41    //     let token = tokenize(text).next()?;
42    //     if token.len as usize != text.len() {
43    //         return None;
44    //     }
45
46    //     let mut conv = Converter::new(text);
47    //     conv.extend_token(&token.kind, text);
48    //     match &*conv.res.kind {
49    //         [kind] => Some((*kind, conv.res.error.pop().map(|it| it.msg))),
50    //         _ => None,
51    //     }
52    // }
53
54    // pub(crate) fn as_str(&self) -> &str {
55    //     self.text
56    // }
57
58    pub(crate) fn len(&self) -> usize {
59        self.kind.len() - 1
60    }
61
62    // pub(crate) fn is_empty(&self) -> bool {
63    //     self.len() == 0
64    // }
65
66    pub(crate) fn kind(&self, i: usize) -> SyntaxKind {
67        assert!(i < self.len());
68        self.kind[i]
69    }
70
71    pub(crate) fn text(&self, i: usize) -> &str {
72        self.range_text(i..i + 1)
73    }
74
75    pub(crate) fn range_text(&self, r: ops::Range<usize>) -> &str {
76        assert!(r.start < r.end && r.end <= self.len());
77        let lo = self.start[r.start] as usize;
78        let hi = self.start[r.end] as usize;
79        &self.text[lo..hi]
80    }
81
82    // Naming is hard.
83    pub fn text_range(&self, i: usize) -> ops::Range<usize> {
84        assert!(i < self.len());
85        let lo = self.start[i] as usize;
86        let hi = self.start[i + 1] as usize;
87        lo..hi
88    }
89    pub fn text_start(&self, i: usize) -> usize {
90        assert!(i <= self.len());
91        self.start[i] as usize
92    }
93    // pub(crate) fn text_len(&self, i: usize) -> usize {
94    //     assert!(i < self.len());
95    //     let r = self.text_range(i);
96    //     r.end - r.start
97    // }
98
99    // pub(crate) fn error(&self, i: usize) -> Option<&str> {
100    //     assert!(i < self.len());
101    //     let err = self
102    //         .error
103    //         .binary_search_by_key(&(i as u32), |i| i.token)
104    //         .ok()?;
105    //     Some(self.error[err].msg.as_str())
106    // }
107
108    pub fn errors(&self) -> impl Iterator<Item = (&ops::Range<u32>, &str)> + '_ {
109        self.error.iter().map(|it| (&it.range, it.msg.as_str()))
110    }
111
112    fn push(&mut self, kind: SyntaxKind, offset: usize) {
113        self.kind.push(kind);
114        self.start.push(offset as u32);
115    }
116}
117
/// Incrementally builds a [`LexedStr`] from the tokens produced by the lexer.
struct Converter<'a> {
    // The token stream being built.
    res: LexedStr<'a>,
    // Byte offset in the source text at which the next token starts; advanced
    // by the length of each pushed token.
    offset: usize,
}
122
/// Returns `true` when `token_text` is a zero-length delimited identifier.
///
/// Without `uescape` the only empty form is `""`. With `uescape` the token
/// must be `U&""` (or `u&""`); anything without that prefix is not empty.
fn is_empty_quoted_ident(token_text: &str, uescape: bool) -> bool {
    let quoted = if !uescape {
        token_text
    } else {
        match token_text
            .strip_prefix("U&")
            .or_else(|| token_text.strip_prefix("u&"))
        {
            Some(rest) => rest,
            None => return false,
        }
    };
    quoted == "\"\""
}
133
134impl<'a> Converter<'a> {
135    fn new(text: &'a str) -> Self {
136        Self {
137            res: LexedStr {
138                text,
139                kind: Vec::new(),
140                start: Vec::new(),
141                error: Vec::new(),
142            },
143            offset: 0,
144        }
145    }
146
147    fn finalize_with_eof(mut self) -> LexedStr<'a> {
148        self.res.push(SyntaxKind::EOF, self.offset);
149        self.res
150    }
151
152    fn push(&mut self, kind: SyntaxKind, len: usize, err: Option<(&str, ops::Range<u32>)>) {
153        let token_start = self.offset as u32;
154        self.res.push(kind, self.offset);
155        self.offset += len;
156
157        if let Some((msg, err_range)) = err {
158            self.res.error.push(LexError {
159                msg: msg.to_owned(),
160                range: token_start + err_range.start..token_start + err_range.end,
161            });
162        }
163    }
164
165    fn extend_token(&mut self, kind: &squawk_lexer::TokenKind, token_text: &str) {
166        // A note on an intended tradeoff:
167        // We drop some useful information here (see patterns with double dots `..`)
168        // Storing that info in `SyntaxKind` is not possible due to its layout requirements of
169        // being `u16` that come from `rowan::SyntaxKind`.
170        let mut err = "";
171        let mut err_range: Option<ops::Range<u32>> = None;
172
173        let syntax_kind = {
174            match kind {
175                squawk_lexer::TokenKind::LineComment => SyntaxKind::COMMENT,
176                squawk_lexer::TokenKind::BlockComment { terminated } => {
177                    if !terminated {
178                        err = "Missing trailing `*/` symbols to terminate the block comment";
179                    }
180                    SyntaxKind::COMMENT
181                }
182
183                squawk_lexer::TokenKind::Whitespace => SyntaxKind::WHITESPACE,
184                squawk_lexer::TokenKind::Ident => {
185                    SyntaxKind::from_keyword(token_text).unwrap_or(SyntaxKind::IDENT)
186                }
187                squawk_lexer::TokenKind::Literal { kind, .. } => {
188                    self.extend_literal(token_text, kind);
189                    return;
190                }
191                squawk_lexer::TokenKind::Semi => SyntaxKind::SEMICOLON,
192                squawk_lexer::TokenKind::Comma => SyntaxKind::COMMA,
193                squawk_lexer::TokenKind::Dot => SyntaxKind::DOT,
194                squawk_lexer::TokenKind::OpenParen => SyntaxKind::L_PAREN,
195                squawk_lexer::TokenKind::CloseParen => SyntaxKind::R_PAREN,
196                squawk_lexer::TokenKind::OpenBracket => SyntaxKind::L_BRACK,
197                squawk_lexer::TokenKind::CloseBracket => SyntaxKind::R_BRACK,
198                squawk_lexer::TokenKind::OpenCurly => SyntaxKind::L_CURLY,
199                squawk_lexer::TokenKind::CloseCurly => SyntaxKind::R_CURLY,
200                squawk_lexer::TokenKind::At => SyntaxKind::AT,
201                squawk_lexer::TokenKind::Pound => SyntaxKind::POUND,
202                squawk_lexer::TokenKind::Tilde => SyntaxKind::TILDE,
203                squawk_lexer::TokenKind::Question => SyntaxKind::QUESTION,
204                squawk_lexer::TokenKind::Colon => SyntaxKind::COLON,
205                squawk_lexer::TokenKind::Eq => SyntaxKind::EQ,
206                squawk_lexer::TokenKind::Bang => SyntaxKind::BANG,
207                squawk_lexer::TokenKind::Lt => SyntaxKind::L_ANGLE,
208                squawk_lexer::TokenKind::Gt => SyntaxKind::R_ANGLE,
209                squawk_lexer::TokenKind::Minus => SyntaxKind::MINUS,
210                squawk_lexer::TokenKind::And => SyntaxKind::AMP,
211                squawk_lexer::TokenKind::Or => SyntaxKind::PIPE,
212                squawk_lexer::TokenKind::Plus => SyntaxKind::PLUS,
213                squawk_lexer::TokenKind::Star => SyntaxKind::STAR,
214                squawk_lexer::TokenKind::Slash => SyntaxKind::SLASH,
215                squawk_lexer::TokenKind::Caret => SyntaxKind::CARET,
216                squawk_lexer::TokenKind::Percent => SyntaxKind::PERCENT,
217                squawk_lexer::TokenKind::Unknown => SyntaxKind::ERROR,
218                squawk_lexer::TokenKind::Eof => SyntaxKind::EOF,
219                squawk_lexer::TokenKind::Backtick => SyntaxKind::BACKTICK,
220                squawk_lexer::TokenKind::PositionalParam {
221                    trailing_junk_start,
222                } => {
223                    let digits = &token_text[1..*trailing_junk_start as usize];
224                    if digits.is_empty() {
225                        err = "missing parameter number";
226                        err_range = Some(0..1);
227                    } else if digits
228                        .parse::<i32>()
229                        .is_err_and(|err| matches!(err.kind(), IntErrorKind::PosOverflow))
230                    {
231                        err = "parameter number too large";
232                        err_range = Some(0..*trailing_junk_start);
233                    } else if (*trailing_junk_start as usize) < token_text.len() {
234                        err = "trailing junk after positional parameter";
235                        err_range = Some(*trailing_junk_start..token_text.len() as u32);
236                    }
237                    SyntaxKind::POSITIONAL_PARAM
238                }
239                squawk_lexer::TokenKind::QuotedIdent {
240                    terminated,
241                    uescape,
242                } => {
243                    if !terminated {
244                        err = "Missing trailing \" to terminate the quoted identifier"
245                    } else if is_empty_quoted_ident(token_text, *uescape) {
246                        err = "empty delimited identifier";
247                    }
248                    SyntaxKind::IDENT
249                }
250            }
251        };
252
253        let err = if err.is_empty() { None } else { Some(err) };
254        let err = err.map(|msg| (msg, err_range.unwrap_or(0..token_text.len() as u32)));
255        self.push(syntax_kind, token_text.len(), err);
256    }
257
258    fn extend_literal(&mut self, token_text: &str, kind: &squawk_lexer::LiteralKind) {
259        let mut err: Option<String> = None;
260        let mut err_range: Option<ops::Range<u32>> = None;
261
262        let syntax_kind = match *kind {
263            squawk_lexer::LiteralKind::Int {
264                empty_int,
265                base,
266                trailing_junk_start,
267            } => {
268                if empty_int {
269                    err = Some("Missing digits after the integer base prefix".into());
270                } else {
271                    if matches!(base, squawk_lexer::Base::Binary | squawk_lexer::Base::Octal) {
272                        let prefix_len = 2u32;
273                        let digits = &token_text[prefix_len as usize..trailing_junk_start as usize];
274                        let base = base as u32;
275                        let token_start = self.offset as u32;
276                        for (i, c) in digits.char_indices() {
277                            if c != '_' && c.to_digit(base).is_none() {
278                                let start = token_start + prefix_len + i as u32;
279                                let end = start + c.len_utf8() as u32;
280                                self.res.error.push(LexError {
281                                    msg: format!("invalid digit for a base {base} literal"),
282                                    range: start..end,
283                                });
284                            }
285                        }
286                    }
287                    if (trailing_junk_start as usize) < token_text.len() {
288                        err = Some("trailing junk after numeric literal".into());
289                        err_range = Some(trailing_junk_start..token_text.len() as u32);
290                    }
291                }
292                SyntaxKind::INT_NUMBER
293            }
294            squawk_lexer::LiteralKind::Numeric {
295                empty_exponent_start,
296                trailing_junk_start,
297            } => {
298                if let Some(exponent_start) = empty_exponent_start {
299                    err = Some("Missing digits after the exponent symbol".into());
300                    err_range = Some(exponent_start..exponent_start + 1);
301                } else if (trailing_junk_start as usize) < token_text.len() {
302                    err = Some("trailing junk after numeric literal".into());
303                    err_range = Some(trailing_junk_start..token_text.len() as u32);
304                }
305                SyntaxKind::NUMERIC_NUMBER
306            }
307            squawk_lexer::LiteralKind::Str { terminated } => {
308                if !terminated {
309                    err =
310                        Some("Missing trailing `'` symbol to terminate the string literal".into());
311                }
312                SyntaxKind::STRING
313            }
314            squawk_lexer::LiteralKind::NationalStr { terminated } => {
315                if !terminated {
316                    err = Some(
317                        "Missing trailing `'` symbol to terminate the national character string literal"
318                            .into(),
319                    );
320                }
321                SyntaxKind::NATIONAL_STRING
322            }
323            squawk_lexer::LiteralKind::ByteStr { terminated } => {
324                if !terminated {
325                    err = Some(
326                        "Missing trailing `'` symbol to terminate the hex bit string literal"
327                            .into(),
328                    );
329                }
330                // digit validation in squawk_syntax
331                SyntaxKind::BYTE_STRING
332            }
333            squawk_lexer::LiteralKind::BitStr { terminated } => {
334                if !terminated {
335                    err = Some(
336                        "Missing trailing `'` symbol to terminate the bit string literal".into(),
337                    );
338                }
339                // digit validation in squawk_syntax
340                SyntaxKind::BIT_STRING
341            }
342            squawk_lexer::LiteralKind::DollarQuotedString { terminated } => {
343                if !terminated {
344                    // TODO: we could be fancier and say the ending string we're looking for
345                    err = Some("Unterminated dollar quoted string literal".into());
346                }
347                SyntaxKind::DOLLAR_QUOTED_STRING
348            }
349            squawk_lexer::LiteralKind::UnicodeEscStr { terminated } => {
350                if !terminated {
351                    err = Some(
352                        "Missing trailing `'` symbol to terminate the unicode escape string literal"
353                            .into(),
354                    );
355                }
356                // validated in squawk_syntax
357                SyntaxKind::UNICODE_ESC_STRING
358            }
359            squawk_lexer::LiteralKind::EscStr { terminated } => {
360                if !terminated {
361                    err = Some(
362                        "Missing trailing `'` symbol to terminate the escape string literal".into(),
363                    );
364                }
365                // unicode escape sequences validated in squawk_syntax
366                SyntaxKind::ESC_STRING
367            }
368        };
369
370        let err = err
371            .as_deref()
372            .map(|msg| (msg, err_range.unwrap_or(0..token_text.len() as u32)));
373        self.push(syntax_kind, token_text.len(), err);
374    }
375}
376
// Snapshot tests for the diagnostics produced by `LexedStr::new`.
#[cfg(test)]
mod tests {
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;

    use super::LexedStr;

    // Lexes `text` and renders each recorded diagnostic as an annotated source
    // snippet, concatenated in order and separated by newlines. Returns an
    // empty string when lexing reports no errors.
    fn lex(text: &str) -> String {
        let lexed = LexedStr::new(text);
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        let mut res = String::new();

        for (range, msg) in lexed.errors() {
            // Error ranges are absolute byte offsets into `text`.
            let span = range.start as usize..range.end as usize;
            let group = Level::ERROR.primary_title(msg).element(
                Snippet::source(text)
                    .fold(true)
                    .annotation(AnnotationKind::Primary.span(span)),
            );
            res.push_str(&renderer.render(&[group]).to_string());
            res.push('\n');
        }

        res
    }

    // A base prefix with no digits flags the whole token.
    #[test]
    fn empty_int_error() {
        assert_snapshot!(lex("select 0x;"), @"
        error: Missing digits after the integer base prefix
          ╭▸ 
        1 │ select 0x;
          ╰╴       ━━
        ");
    }

    // Only the junk after the prefix is flagged here, not the prefix itself.
    #[test]
    fn empty_int_with_trailing_ident_error() {
        assert_snapshot!(lex("select 0xg;"), @"
        error: trailing junk after numeric literal
          ╭▸ 
        1 │ select 0xg;
          ╰╴         ━
        ");
    }

    // Each out-of-range digit gets its own diagnostic.
    #[test]
    fn invalid_octal_digits_error() {
        assert_snapshot!(lex("select 0o999;"), @"
        error: invalid digit for a base 8 literal
          ╭▸ 
        1 │ select 0o999;
          ╰╴         ━
        error: invalid digit for a base 8 literal
          ╭▸ 
        1 │ select 0o999;
          ╰╴          ━
        error: invalid digit for a base 8 literal
          ╭▸ 
        1 │ select 0o999;
          ╰╴           ━
        ");
    }

    #[test]
    fn invalid_binary_digits_error() {
        assert_snapshot!(lex("select 0b234;"), @"
        error: invalid digit for a base 2 literal
          ╭▸ 
        1 │ select 0b234;
          ╰╴         ━
        error: invalid digit for a base 2 literal
          ╭▸ 
        1 │ select 0b234;
          ╰╴          ━
        error: invalid digit for a base 2 literal
          ╭▸ 
        1 │ select 0b234;
          ╰╴           ━
        ");
    }

    // Valid leading digits are not flagged; only the bad ones after them.
    #[test]
    fn invalid_octal_digits_after_valid_error() {
        assert_snapshot!(lex("select 0o7889;"), @"
        error: invalid digit for a base 8 literal
          ╭▸ 
        1 │ select 0o7889;
          ╰╴          ━
        error: invalid digit for a base 8 literal
          ╭▸ 
        1 │ select 0o7889;
          ╰╴           ━
        error: invalid digit for a base 8 literal
          ╭▸ 
        1 │ select 0o7889;
          ╰╴            ━
        ");
    }

    // The diagnostic points at the exponent marker only.
    #[test]
    fn empty_exponent_error() {
        assert_snapshot!(lex("select 1e;"), @"
        error: Missing digits after the exponent symbol
          ╭▸ 
        1 │ select 1e;
          ╰╴        ━
        ");
    }

    // Unterminated string-like literals flag the whole (runaway) token.
    #[test]
    fn unterminated_string_error() {
        assert_snapshot!(lex("select 'hello;"), @"
        error: Missing trailing `'` symbol to terminate the string literal
          ╭▸ 
        1 │ select 'hello;
          ╰╴       ━━━━━━━
        ");
    }

    #[test]
    fn unterminated_hex_bit_string_error() {
        assert_snapshot!(lex("select X'1F;"), @"
        error: Missing trailing `'` symbol to terminate the hex bit string literal
          ╭▸ 
        1 │ select X'1F;
          ╰╴       ━━━━━
        ");
    }

    #[test]
    fn unterminated_bit_string_error() {
        assert_snapshot!(lex("select B'101;"), @"
        error: Missing trailing `'` symbol to terminate the bit string literal
          ╭▸ 
        1 │ select B'101;
          ╰╴       ━━━━━━
        ");
    }

    #[test]
    fn unterminated_dollar_quoted_string_error() {
        assert_snapshot!(lex("select $tag$hello;"), @"
        error: Unterminated dollar quoted string literal
          ╭▸ 
        1 │ select $tag$hello;
          ╰╴       ━━━━━━━━━━━
        ");
    }

    #[test]
    fn unterminated_unicode_escape_string_error() {
        assert_snapshot!(lex("select U&'hello;"), @"
        error: Missing trailing `'` symbol to terminate the unicode escape string literal
          ╭▸ 
        1 │ select U&'hello;
          ╰╴       ━━━━━━━━━
        ");
    }

    #[test]
    fn unterminated_escape_string_error() {
        assert_snapshot!(lex("select E'hello;"), @"
        error: Missing trailing `'` symbol to terminate the escape string literal
          ╭▸ 
        1 │ select E'hello;
          ╰╴       ━━━━━━━━
        ");
    }
}