sqlite3_parser/lexer/sql/mod.rs

//! Adaptation/port of [`SQLite` tokenizer](http://www.sqlite.org/src/artifact?ci=trunk&filename=src/tokenize.c)
use fallible_iterator::FallibleIterator;
use memchr::memchr;

pub use crate::dialect::TokenType;
use crate::dialect::TokenType::*;
use crate::dialect::{
    is_identifier_continue, is_identifier_start, keyword_token, sentinel, MAX_KEYWORD_LEN,
};
use crate::parser::ast::Cmd;
use crate::parser::parse::{yyParser, YYCODETYPE};
use crate::parser::Context;
mod error;
#[cfg(test)]
mod test;

use crate::lexer::scan::ScanError;
use crate::lexer::scan::Splitter;
use crate::lexer::Scanner;
pub use crate::parser::ParserError;
pub use error::Error;

// TODO Extract the scanning stuff and move this into the parser crate
// to make it possible to use the tokenizer without depending on the parser...

/// SQL parser
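///
/// Yields one [`Cmd`] per SQL statement through the [`FallibleIterator`]
/// interface. A minimal usage sketch:
///
/// ```rust
/// use fallible_iterator::FallibleIterator;
/// use sqlite3_parser::lexer::sql::Parser;
///
/// let mut parser = Parser::new(b"SELECT 1;");
/// let cmd = parser.next().unwrap(); // `next` returns Result<Option<Cmd>, Error>
/// assert!(cmd.is_some());
/// ```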
pub struct Parser<'input> {
    input: &'input [u8],
    scanner: Scanner<Tokenizer>,
    parser: yyParser<'input>,
}

impl<'input> Parser<'input> {
    /// Constructor
    pub fn new(input: &'input [u8]) -> Self {
        let lexer = Tokenizer::new();
        let scanner = Scanner::new(lexer);
        let ctx = Context::new(input);
        let parser = yyParser::new(ctx);
        Parser {
            input,
            scanner,
            parser,
        }
    }
    /// Parse new `input`
    pub fn reset(&mut self, input: &'input [u8]) {
        self.input = input;
        self.scanner.reset();
    }
    /// Current line position in input
    pub fn line(&self) -> u64 {
        self.scanner.line()
    }
    /// Current column position in input
    pub fn column(&self) -> usize {
        self.scanner.column()
    }
}

/*
 ** Return the id of the next token in input.
 */
fn get_token(scanner: &mut Scanner<Tokenizer>, input: &[u8]) -> Result<TokenType, Error> {
    let mut t = {
        let (_, token_type) = match scanner.scan(input)? {
            (_, None, _) => {
                return Ok(TK_EOF);
            }
            (_, Some(tuple), _) => tuple,
        };
        token_type
    };
    if t == TK_ID
        || t == TK_STRING
        || t == TK_JOIN_KW
        || t == TK_WINDOW
        || t == TK_OVER
        || yyParser::parse_fallback(t as YYCODETYPE) == TK_ID as YYCODETYPE
    {
        t = TK_ID;
    }
    Ok(t)
}

/*
 ** The following three functions are called immediately after the tokenizer
 ** reads the keywords WINDOW, OVER and FILTER, respectively, to determine
 ** whether the token should be treated as a keyword or an SQL identifier.
 ** This cannot be handled by the usual lemon %fallback method, due to
 ** the ambiguity in some constructions. e.g.
 **
 **   SELECT sum(x) OVER ...
 **
 ** In the above, "OVER" might be a keyword, or it might be an alias for the
 ** sum(x) expression. If a "%fallback ID OVER" directive were added to the
 ** grammar, then SQLite would always treat "OVER" as an alias, making it
 ** impossible to call a window-function without a FILTER clause.
 **
 ** WINDOW is treated as a keyword if:
 **
 **   * the following token is an identifier, or a keyword that can fallback
 **     to being an identifier, and
 **   * the token after that one is TK_AS.
 **
 ** OVER is a keyword if:
 **
 **   * the previous token was TK_RP, and
 **   * the next token is either TK_LP or an identifier.
 **
 ** FILTER is a keyword if:
 **
 **   * the previous token was TK_RP, and
 **   * the next token is TK_LP.
 */
fn analyze_window_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
) -> Result<TokenType, Error> {
    let t = get_token(scanner, input)?;
    if t != TK_ID {
        return Ok(TK_ID);
    }
    let t = get_token(scanner, input)?;
    if t != TK_AS {
        return Ok(TK_ID);
    }
    Ok(TK_WINDOW)
}

fn analyze_over_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP {
        let t = get_token(scanner, input)?;
        if t == TK_LP || t == TK_ID {
            return Ok(TK_OVER);
        }
    }
    Ok(TK_ID)
}

fn analyze_filter_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP && get_token(scanner, input)? == TK_LP {
        return Ok(TK_FILTER);
    }
    Ok(TK_ID)
}

macro_rules! try_with_position {
    ($scanner:expr, $expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(err) => {
                let mut err = Error::from(err);
                err.position($scanner.line(), $scanner.column());
                return Err(err);
            }
        }
    };
}

impl FallibleIterator for Parser<'_> {
    type Item = Cmd;
    type Error = Error;

    fn next(&mut self) -> Result<Option<Cmd>, Error> {
        //print!("line: {}, column: {}: ", self.scanner.line(), self.scanner.column());
        self.parser.ctx.reset();
        let mut last_token_parsed = TK_EOF;
        let mut eof = false;
        loop {
            let (start, (value, mut token_type), end) = match self.scanner.scan(self.input)? {
                (_, None, _) => {
                    eof = true;
                    break;
                }
                (start, Some(tuple), end) => (start, tuple, end),
            };
            let token = if token_type >= TK_WINDOW {
                debug_assert!(
                    token_type == TK_OVER || token_type == TK_FILTER || token_type == TK_WINDOW
                );
                self.scanner.mark();
                if token_type == TK_WINDOW {
                    token_type = analyze_window_keyword(&mut self.scanner, self.input)?;
                } else if token_type == TK_OVER {
                    token_type =
                        analyze_over_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                } else if token_type == TK_FILTER {
                    token_type =
                        analyze_filter_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                }
                self.scanner.reset_to_mark();
                token_type.to_token(start, value, end)
            } else {
                token_type.to_token(start, value, end)
            };
            //println!("({:?}, {:?})", token_type, token);
            try_with_position!(self.scanner, self.parser.sqlite3Parser(token_type, token));
            last_token_parsed = token_type;
            if self.parser.ctx.done() {
                //println!();
                break;
            }
        }
        if last_token_parsed == TK_EOF {
            return Ok(None); // empty input
        }
        /* Upon reaching the end of input, call the parser two more times
        with tokens TK_SEMI and TK_EOF, in that order. */
        if eof && self.parser.ctx.is_ok() {
            if last_token_parsed != TK_SEMI {
                try_with_position!(
                    self.scanner,
                    self.parser
                        .sqlite3Parser(TK_SEMI, sentinel(self.input.len()))
                );
            }
            try_with_position!(
                self.scanner,
                self.parser
                    .sqlite3Parser(TK_EOF, sentinel(self.input.len()))
            );
        }
        self.parser.sqlite3ParserFinalize();
        if let Some(e) = self.parser.ctx.error() {
            let err = Error::ParserError(e, Some((self.scanner.line(), self.scanner.column())));
            return Err(err);
        }
        let cmd = self.parser.ctx.cmd();
        if let Some(ref cmd) = cmd {
            if let Err(e) = cmd.check() {
                let err = Error::ParserError(e, Some((self.scanner.line(), self.scanner.column())));
                return Err(err);
            }
        }
        Ok(cmd)
    }
}

/// SQL token
pub type Token<'input> = (&'input [u8], TokenType);

/// SQL lexer
#[derive(Default)]
pub struct Tokenizer {}

impl Tokenizer {
    /// Constructor
    pub fn new() -> Self {
        Self {}
    }
}

/// ```rust
/// use sqlite3_parser::lexer::sql::Tokenizer;
/// use sqlite3_parser::lexer::Scanner;
///
/// let tokenizer = Tokenizer::new();
/// let input = b"PRAGMA parser_trace=ON;";
/// let mut s = Scanner::new(tokenizer);
/// let Ok((_, Some((token1, _)), _)) = s.scan(input) else { panic!() };
/// s.scan(input).unwrap();
/// assert!(b"PRAGMA".eq_ignore_ascii_case(token1));
/// ```
impl Splitter for Tokenizer {
    type Error = Error;
    type TokenType = TokenType;

    fn split<'input>(
        &mut self,
        data: &'input [u8],
    ) -> Result<(Option<Token<'input>>, usize), Error> {
        if data[0].is_ascii_whitespace() {
            // eat as much space as possible
            return Ok((
                None,
                match data.iter().skip(1).position(|&b| !b.is_ascii_whitespace()) {
                    Some(i) => i + 1,
                    _ => data.len(),
                },
            ));
        }
        match data[0] {
            b'-' => {
                if let Some(b) = data.get(1) {
                    if *b == b'-' {
                        // eat comment
                        if let Some(i) = memchr(b'\n', data) {
                            Ok((None, i + 1))
                        } else {
                            Ok((None, data.len()))
                        }
                    } else if *b == b'>' {
                        if let Some(b) = data.get(2) {
                            if *b == b'>' {
                                return Ok((Some((&data[..3], TK_PTR)), 3));
                            }
                        }
                        Ok((Some((&data[..2], TK_PTR)), 2))
                    } else {
                        Ok((Some((&data[..1], TK_MINUS)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_MINUS)), 1))
                }
            }
            b'(' => Ok((Some((&data[..1], TK_LP)), 1)),
            b')' => Ok((Some((&data[..1], TK_RP)), 1)),
            b';' => Ok((Some((&data[..1], TK_SEMI)), 1)),
            b'+' => Ok((Some((&data[..1], TK_PLUS)), 1)),
            b'*' => Ok((Some((&data[..1], TK_STAR)), 1)),
            b'/' => {
                if let Some(b) = data.get(1) {
                    if *b == b'*' {
                        // eat comment
                        let mut pb = 0;
                        let mut end = None;
                        for (i, b) in data.iter().enumerate().skip(2) {
                            if *b == b'/' && pb == b'*' {
                                end = Some(i);
                                break;
                            }
                            pb = *b;
                        }
                        if let Some(i) = end {
                            Ok((None, i + 1))
                        } else {
                            Err(Error::UnterminatedBlockComment(None))
                        }
                    } else {
                        Ok((Some((&data[..1], TK_SLASH)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_SLASH)), 1))
                }
            }
            b'%' => Ok((Some((&data[..1], TK_REM)), 1)),
            b'=' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'=' {
                        (Some((&data[..2], TK_EQ)), 2)
                    } else {
                        (Some((&data[..1], TK_EQ)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_EQ)), 1))
                }
            }
            b'<' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_LE)), 2),
                        b'>' => (Some((&data[..2], TK_NE)), 2),
                        b'<' => (Some((&data[..2], TK_LSHIFT)), 2),
                        _ => (Some((&data[..1], TK_LT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_LT)), 1))
                }
            }
            b'>' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_GE)), 2),
                        b'>' => (Some((&data[..2], TK_RSHIFT)), 2),
                        _ => (Some((&data[..1], TK_GT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_GT)), 1))
                }
            }
            b'!' => {
                if let Some(b) = data.get(1) {
                    if *b == b'=' {
                        Ok((Some((&data[..2], TK_NE)), 2))
                    } else {
                        Err(Error::ExpectedEqualsSign(None))
                    }
                } else {
                    Err(Error::ExpectedEqualsSign(None))
                }
            }
            b'|' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'|' {
                        (Some((&data[..2], TK_CONCAT)), 2)
                    } else {
                        (Some((&data[..1], TK_BITOR)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_BITOR)), 1))
                }
            }
            b',' => Ok((Some((&data[..1], TK_COMMA)), 1)),
            b'&' => Ok((Some((&data[..1], TK_BITAND)), 1)),
            b'~' => Ok((Some((&data[..1], TK_BITNOT)), 1)),
            quote @ (b'`' | b'\'' | b'"') => literal(data, quote),
            b'.' => {
                if let Some(b) = data.get(1) {
                    if b.is_ascii_digit() {
                        fractional_part(data, 0)
                    } else {
                        Ok((Some((&data[..1], TK_DOT)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_DOT)), 1))
                }
            }
            b'0'..=b'9' => number(data),
            b'[' => {
                if let Some(i) = memchr(b']', data) {
                    // Keep original quotes / '[' ... ']'
                    Ok((Some((&data[0..=i], TK_ID)), i + 1))
                } else {
                    Err(Error::UnterminatedBracket(None))
                }
            }
            b'?' => {
                match data.iter().skip(1).position(|&b| !b.is_ascii_digit()) {
                    Some(i) => {
                        // do not include the '?' in the token
                        Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1))
                    }
                    None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
                }
            }
            b'$' | b'@' | b'#' | b':' => {
                match data
                    .iter()
                    .skip(1)
                    .position(|&b| !is_identifier_continue(b))
                {
                    Some(0) => Err(Error::BadVariableName(None)),
                    Some(i) => {
                        // '$' is included as part of the name
                        Ok((Some((&data[..=i], TK_VARIABLE)), i + 1))
                    }
                    None => {
                        if data.len() == 1 {
                            return Err(Error::BadVariableName(None));
                        }
                        Ok((Some((data, TK_VARIABLE)), data.len()))
                    }
                }
            }
            b if is_identifier_start(b) => {
                if b == b'x' || b == b'X' {
                    if let Some(&b'\'') = data.get(1) {
                        blob_literal(data)
                    } else {
                        Ok(self.identifierish(data))
                    }
                } else {
                    Ok(self.identifierish(data))
                }
            }
            _ => Err(Error::UnrecognizedToken(None)),
        }
    }
}

fn literal(data: &[u8], quote: u8) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], quote);
    let tt = if quote == b'\'' { TK_STRING } else { TK_ID };
    let mut pb = 0;
    let mut end = None;
    // data[0] == quote => skip(1)
    for (i, b) in data.iter().enumerate().skip(1) {
        if *b == quote {
            if pb == quote {
                // escaped quote
                pb = 0;
                continue;
            }
        } else if pb == quote {
            end = Some(i);
            break;
        }
        pb = *b;
    }
    if end.is_some() || pb == quote {
        let i = match end {
            Some(i) => i,
            _ => data.len(),
        };
        // keep original quotes in the token
        Ok((Some((&data[0..i], tt)), i))
    } else {
        Err(Error::UnterminatedLiteral(None))
    }
}

fn blob_literal(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0] == b'x' || data[0] == b'X');
    debug_assert_eq!(data[1], b'\'');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &b)| !b.is_ascii_hexdigit())
    {
        if *b != b'\'' || i % 2 != 0 {
            return Err(Error::MalformedBlobLiteral(None));
        }
        Ok((Some((&data[2..i], TK_BLOB)), i + 1))
    } else {
        Err(Error::MalformedBlobLiteral(None))
    }
}

fn number(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0].is_ascii_digit());
    if data[0] == b'0' {
        if let Some(b) = data.get(1) {
            if *b == b'x' || *b == b'X' {
                return hex_integer(data);
            }
        } else {
            return Ok((Some((data, TK_INTEGER)), data.len()));
        }
    }
    if let Some((i, b)) = find_end_of_number(data, 1, u8::is_ascii_digit)? {
        if b == b'.' {
            return fractional_part(data, i);
        } else if b == b'e' || b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}

fn hex_integer(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], b'0');
    debug_assert!(data[1] == b'x' || data[1] == b'X');
    if let Some((i, b)) = find_end_of_number(data, 2, u8::is_ascii_hexdigit)? {
        // Must not be empty (0x is invalid)
        if i == 2 || is_identifier_start(b) {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        // Must not be empty (0x is invalid)
        if data.len() == 2 {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}

fn fractional_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[i], b'.');
    if let Some((i, b)) = find_end_of_number(data, i + 1, u8::is_ascii_digit)? {
        if b == b'e' || b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_FLOAT)), i))
    } else {
        Ok((Some((data, TK_FLOAT)), data.len()))
    }
}

fn exponential_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[i] == b'e' || data[i] == b'E');
    if let Some(b) = data.get(i + 1) {
        let i = if *b == b'+' || *b == b'-' { i + 1 } else { i };
        if let Some((j, b)) = find_end_of_number(data, i + 1, u8::is_ascii_digit)? {
            if j == i + 1 || is_identifier_start(b) {
                return Err(Error::BadNumber(None));
            }
            Ok((Some((&data[..j], TK_FLOAT)), j))
        } else if data.len() == i + 1 {
            Err(Error::BadNumber(None))
        } else {
            Ok((Some((data, TK_FLOAT)), data.len()))
        }
    } else {
        Err(Error::BadNumber(None))
    }
}

fn find_end_of_number(
    data: &[u8],
    i: usize,
    test: fn(&u8) -> bool,
) -> Result<Option<(usize, u8)>, Error> {
    for (j, &b) in data.iter().enumerate().skip(i) {
        if test(&b) {
            continue;
        } else if b == b'_' {
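            // '_' is only valid as a digit separator: it must be surrounded
            // on both sides by bytes accepted by `test` (e.g. 1_000, 0xdead_beef)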
            if j >= 1 && data.get(j - 1).is_some_and(test) && data.get(j + 1).is_some_and(test) {
                continue;
            }
            return Err(Error::BadNumber(None));
        } else {
            return Ok(Some((j, b)));
        }
    }
    Ok(None)
}

impl Tokenizer {
    fn identifierish<'input>(&mut self, data: &'input [u8]) -> (Option<Token<'input>>, usize) {
        debug_assert!(is_identifier_start(data[0]));
        // data[0] is_identifier_start => skip(1)
        let end = data
            .iter()
            .skip(1)
            .position(|&b| !is_identifier_continue(b));
        let i = match end {
            Some(i) => i + 1,
            _ => data.len(),
        };
        let word = &data[..i];
        let tt = if word.len() >= 2 && word.len() <= MAX_KEYWORD_LEN && word.is_ascii() {
            keyword_token(word).unwrap_or(TK_ID)
        } else {
            TK_ID
        };
        (Some((word, tt)), i)
    }
}

#[cfg(test)]
mod tests {
    use super::Tokenizer;
    use crate::dialect::TokenType;
    use crate::lexer::sql::Error;
    use crate::lexer::Scanner;

    #[test]
    fn fallible_iterator() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"PRAGMA parser_trace=ON;";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"PRAGMA", TokenType::TK_PRAGMA)?;
        expect_token(&mut s, input, b"parser_trace", TokenType::TK_ID)?;
        Ok(())
    }
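
    // Sketch of the doubled-quote escape handled by `literal`: the token
    // keeps its original quotes, so the expected value below includes them.
    #[test]
    fn string_literal_escaped_quote() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"'it''s';";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"'it''s'", TokenType::TK_STRING)?;
        Ok(())
    }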

    #[test]
    fn invalid_number_literal() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"SELECT 1E;";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"SELECT", TokenType::TK_SELECT)?;
        let err = s.scan(input).unwrap_err();
        assert!(matches!(err, Error::BadNumber(_)));
        Ok(())
    }
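
    // Sketch of the '_' digit-separator rule in `find_end_of_number`:
    // an underscore between two digits is part of the integer token.
    #[test]
    fn numeric_digit_separator() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"1_000;";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"1_000", TokenType::TK_INTEGER)?;
        Ok(())
    }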

    fn expect_token(
        s: &mut Scanner<Tokenizer>,
        input: &[u8],
        token: &[u8],
        token_type: TokenType,
    ) -> Result<(), Error> {
        let (t, tt) = s.scan(input)?.1.unwrap();
        assert_eq!(token, t);
        assert_eq!(token_type, tt);
        Ok(())
    }
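
    // End-to-end sketch of the OVER lookahead in `analyze_over_keyword`:
    // after ')' and before '(', OVER stays a keyword, so a window function
    // parses (assumes the grammar accepts this statement as written).
    #[test]
    fn over_keyword_lookahead() -> Result<(), Error> {
        use super::Parser;
        use fallible_iterator::FallibleIterator;
        let mut parser = Parser::new(b"SELECT sum(x) OVER (ORDER BY x) FROM t;");
        assert!(parser.next()?.is_some());
        Ok(())
    }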
}