sqlite3_parser/lexer/sql/mod.rs

//! Adaptation/port of [`SQLite` tokenizer](http://www.sqlite.org/src/artifact?ci=trunk&filename=src/tokenize.c)
use fallible_iterator::FallibleIterator;
use memchr::memchr;

pub use crate::dialect::TokenType;
use crate::dialect::TokenType::*;
use crate::dialect::{
    is_identifier_continue, is_identifier_start, keyword_token, sentinel, MAX_KEYWORD_LEN,
};
use crate::parser::ast::Cmd;
use crate::parser::parse::{yyParser, YYCODETYPE};
use crate::parser::Context;

mod error;
#[cfg(test)]
mod test;

use crate::lexer::scan::{Pos, ScanError, Splitter};
use crate::lexer::Scanner;
pub use crate::parser::ParserError;
pub use error::Error;

// TODO Extract scanning stuff and move this into the parser crate
// to make it possible to use the tokenizer without depending on the parser...

/// SQL parser
pub struct Parser<'input> {
    input: &'input [u8],
    scanner: Scanner<Tokenizer>,
    parser: yyParser<'input>,
}

impl<'input> Parser<'input> {
    /// Constructor
    pub fn new(input: &'input [u8]) -> Self {
        let lexer = Tokenizer::new();
        let scanner = Scanner::new(lexer);
        let ctx = Context::new(input);
        let parser = yyParser::new(ctx);
        Parser {
            input,
            scanner,
            parser,
        }
    }
    /// Parse new `input`
    pub fn reset(&mut self, input: &'input [u8]) {
        self.input = input;
        self.scanner.reset();
    }
    /// Current position in input
    pub fn position(&self) -> Pos {
        self.scanner.position(self.input)
    }
}

/*
 ** Return the id of the next token in input.
 */
fn get_token(scanner: &mut Scanner<Tokenizer>, input: &[u8]) -> Result<TokenType, Error> {
    let mut t = match scanner.scan(input)? {
        (_, None, _) => return Ok(TK_EOF),
        (_, Some((_, token_type)), _) => token_type,
    };
    if t == TK_ID
        || t == TK_STRING
        || t == TK_JOIN_KW
        || t == TK_WINDOW
        || t == TK_OVER
        || yyParser::parse_fallback(t as YYCODETYPE) == TK_ID as YYCODETYPE
    {
        t = TK_ID;
    }
    Ok(t)
}

/*
 ** The following three functions are called immediately after the tokenizer
 ** reads the keywords WINDOW, OVER and FILTER, respectively, to determine
 ** whether the token should be treated as a keyword or an SQL identifier.
 ** This cannot be handled by the usual lemon %fallback method, due to
 ** the ambiguity in some constructions. e.g.
 **
 **   SELECT sum(x) OVER ...
 **
 ** In the above, "OVER" might be a keyword, or it might be an alias for the
 ** sum(x) expression. If a "%fallback ID OVER" directive were added to the
 ** grammar, then SQLite would always treat "OVER" as an alias, making it
 ** impossible to call a window-function without a FILTER clause.
 **
 ** WINDOW is treated as a keyword if:
 **
 **   * the following token is an identifier, or a keyword that can fallback
 **     to being an identifier, and
 **   * the token after that one is TK_AS.
 **
 ** OVER is a keyword if:
 **
 **   * the previous token was TK_RP, and
 **   * the next token is either TK_LP or an identifier.
 **
 ** FILTER is a keyword if:
 **
 **   * the previous token was TK_RP, and
 **   * the next token is TK_LP.
 */
fn analyze_window_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
) -> Result<TokenType, Error> {
    let t = get_token(scanner, input)?;
    if t != TK_ID {
        return Ok(TK_ID);
    }
    let t = get_token(scanner, input)?;
    if t != TK_AS {
        return Ok(TK_ID);
    }
    Ok(TK_WINDOW)
}
fn analyze_over_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP {
        let t = get_token(scanner, input)?;
        if t == TK_LP || t == TK_ID {
            return Ok(TK_OVER);
        }
    }
    Ok(TK_ID)
}
fn analyze_filter_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP && get_token(scanner, input)? == TK_LP {
        return Ok(TK_FILTER);
    }
    Ok(TK_ID)
}

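// Attach the current byte offset to an error before propagating it.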
macro_rules! try_with_position {
    ($input:expr, $offset:expr, $expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(err) => {
                let mut err = Error::from(err);
                err.position(Pos::from($input, $offset));
                return Err(err);
            }
        }
    };
}

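/// Parsing loop sketch (names as in this crate; printing via `Cmd`'s
/// `Display` impl is an assumption of this example):
///
/// ```rust
/// use fallible_iterator::FallibleIterator;
/// use sqlite3_parser::lexer::sql::Parser;
///
/// let mut parser = Parser::new(b"SELECT 1; SELECT 2;");
/// // Each `next` call yields one parsed command, or `None` at end of input.
/// while let Some(cmd) = parser.next().unwrap() {
///     println!("{cmd}");
/// }
/// ```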
impl FallibleIterator for Parser<'_> {
    type Item = Cmd;
    type Error = Error;

    fn next(&mut self) -> Result<Option<Cmd>, Error> {
        //print!("line: {}, column: {}: ", self.scanner.line(), self.scanner.column());
        self.parser.ctx.reset();
        let mut last_token_parsed = TK_EOF;
        let offset;
        let mut eof = false;
        loop {
            let (start, (value, mut token_type), end) = match self.scanner.scan(self.input)? {
                (start, None, _) => {
                    offset = start;
                    eof = true;
                    break;
                }
                (start, Some(tuple), end) => (start, tuple, end),
            };
            let token = if token_type >= TK_WINDOW {
                debug_assert!(
                    token_type == TK_OVER || token_type == TK_FILTER || token_type == TK_WINDOW
                );
                self.scanner.mark();
                if token_type == TK_WINDOW {
                    token_type = analyze_window_keyword(&mut self.scanner, self.input)?;
                } else if token_type == TK_OVER {
                    token_type =
                        analyze_over_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                } else if token_type == TK_FILTER {
                    token_type =
                        analyze_filter_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                }
                self.scanner.reset_to_mark();
                token_type.to_token(start, value, end)
            } else {
                token_type.to_token(start, value, end)
            };
            //println!("({:?}, {:?})", token_type, token);
            try_with_position!(
                self.input,
                start,
                self.parser.sqlite3Parser(token_type, token)
            );
            last_token_parsed = token_type;
            if self.parser.ctx.done() {
                //println!();
                offset = start;
                break;
            }
        }
        if last_token_parsed == TK_EOF {
            return Ok(None); // empty input
        }
        /* Upon reaching the end of input, call the parser two more times
        with tokens TK_SEMI and TK_EOF, in that order. */
        if eof && self.parser.ctx.is_ok() {
            if last_token_parsed != TK_SEMI {
                try_with_position!(
                    self.input,
                    offset,
                    self.parser
                        .sqlite3Parser(TK_SEMI, sentinel(self.input.len()))
                );
            }
            try_with_position!(
                self.input,
                offset,
                self.parser
                    .sqlite3Parser(TK_EOF, sentinel(self.input.len()))
            );
        }
        self.parser.sqlite3ParserFinalize();
        if let Some(e) = self.parser.ctx.error() {
            let err = Error::ParserError(e, Some(Pos::from(self.input, offset)));
            return Err(err);
        }
        let cmd = self.parser.ctx.cmd();
        #[cfg(feature = "extra_checks")]
        if let Some(ref cmd) = cmd {
            if let Err(e) = cmd.check() {
                let err = Error::ParserError(e, Some(Pos::from(self.input, offset)));
                return Err(err);
            }
        }
        Ok(cmd)
    }
}

/// SQL token
pub type Token<'input> = (&'input [u8], TokenType);

/// SQL lexer
#[derive(Default)]
pub struct Tokenizer {}

impl Tokenizer {
    /// Constructor
    pub fn new() -> Self {
        Self {}
    }
}

/// ```rust
/// use sqlite3_parser::lexer::sql::Tokenizer;
/// use sqlite3_parser::lexer::Scanner;
///
/// let tokenizer = Tokenizer::new();
/// let input = b"PRAGMA parser_trace=ON;";
/// let mut s = Scanner::new(tokenizer);
/// let Ok((_, Some((token1, _)), _)) = s.scan(input) else { panic!() };
/// s.scan(input).unwrap();
/// assert!(b"PRAGMA".eq_ignore_ascii_case(token1));
/// ```
impl Splitter for Tokenizer {
    type Error = Error;
    type TokenType = TokenType;

    fn split<'input>(
        &mut self,
        data: &'input [u8],
    ) -> Result<(Option<Token<'input>>, usize), Error> {
        if data[0].is_ascii_whitespace() {
            // eat as much space as possible
            return Ok((
                None,
                match data.iter().skip(1).position(|&b| !b.is_ascii_whitespace()) {
                    Some(i) => i + 1,
                    _ => data.len(),
                },
            ));
        }
        match data[0] {
            b'-' => {
                if let Some(b) = data.get(1) {
                    if *b == b'-' {
                        // eat comment
                        if let Some(i) = memchr(b'\n', data) {
                            Ok((None, i + 1))
                        } else {
                            Ok((None, data.len()))
                        }
                    } else if *b == b'>' {
                        if let Some(b) = data.get(2) {
                            if *b == b'>' {
                                return Ok((Some((&data[..3], TK_PTR)), 3));
                            }
                        }
                        Ok((Some((&data[..2], TK_PTR)), 2))
                    } else {
                        Ok((Some((&data[..1], TK_MINUS)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_MINUS)), 1))
                }
            }
            b'(' => Ok((Some((&data[..1], TK_LP)), 1)),
            b')' => Ok((Some((&data[..1], TK_RP)), 1)),
            b';' => Ok((Some((&data[..1], TK_SEMI)), 1)),
            b'+' => Ok((Some((&data[..1], TK_PLUS)), 1)),
            b'*' => Ok((Some((&data[..1], TK_STAR)), 1)),
            b'/' => {
                if let Some(b) = data.get(1) {
                    if *b == b'*' {
                        // eat comment
                        let mut pb = 0;
                        let mut end = None;
                        for (i, b) in data.iter().enumerate().skip(2) {
                            if *b == b'/' && pb == b'*' {
                                end = Some(i);
                                break;
                            }
                            pb = *b;
                        }
                        if let Some(i) = end {
                            Ok((None, i + 1))
                        } else {
                            Err(Error::UnterminatedBlockComment(None))
                        }
                    } else {
                        Ok((Some((&data[..1], TK_SLASH)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_SLASH)), 1))
                }
            }
            b'%' => Ok((Some((&data[..1], TK_REM)), 1)),
            b'=' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'=' {
                        (Some((&data[..2], TK_EQ)), 2)
                    } else {
                        (Some((&data[..1], TK_EQ)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_EQ)), 1))
                }
            }
            b'<' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_LE)), 2),
                        b'>' => (Some((&data[..2], TK_NE)), 2),
                        b'<' => (Some((&data[..2], TK_LSHIFT)), 2),
                        _ => (Some((&data[..1], TK_LT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_LT)), 1))
                }
            }
            b'>' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_GE)), 2),
                        b'>' => (Some((&data[..2], TK_RSHIFT)), 2),
                        _ => (Some((&data[..1], TK_GT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_GT)), 1))
                }
            }
            b'!' => {
                if let Some(b) = data.get(1) {
                    if *b == b'=' {
                        Ok((Some((&data[..2], TK_NE)), 2))
                    } else {
                        Err(Error::ExpectedEqualsSign(None))
                    }
                } else {
                    Err(Error::ExpectedEqualsSign(None))
                }
            }
            b'|' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'|' {
                        (Some((&data[..2], TK_CONCAT)), 2)
                    } else {
                        (Some((&data[..1], TK_BITOR)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_BITOR)), 1))
                }
            }
            b',' => Ok((Some((&data[..1], TK_COMMA)), 1)),
            b'&' => Ok((Some((&data[..1], TK_BITAND)), 1)),
            b'~' => Ok((Some((&data[..1], TK_BITNOT)), 1)),
            quote @ (b'`' | b'\'' | b'"') => literal(data, quote),
            b'.' => {
                if let Some(b) = data.get(1) {
                    if b.is_ascii_digit() {
                        fractional_part(data, 0)
                    } else {
                        Ok((Some((&data[..1], TK_DOT)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_DOT)), 1))
                }
            }
            b'0'..=b'9' => number(data),
            b'[' => {
                if let Some(i) = memchr(b']', data) {
                    // Keep original quotes: '[' ... ']'
                    Ok((Some((&data[0..=i], TK_ID)), i + 1))
                } else {
                    Err(Error::UnterminatedBracket(None))
                }
            }
            b'?' => {
                match data.iter().skip(1).position(|&b| !b.is_ascii_digit()) {
                    Some(i) => {
                        // do not include the '?' in the token
                        Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1))
                    }
                    None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
                }
            }
            b'$' | b'@' | b'#' | b':' => {
                match data
                    .iter()
                    .skip(1)
                    .position(|&b| !is_identifier_continue(b))
                {
                    Some(0) => Err(Error::BadVariableName(None)),
                    Some(i) => {
                        // '$' is included as part of the name
                        Ok((Some((&data[..=i], TK_VARIABLE)), i + 1))
                    }
                    None => {
                        if data.len() == 1 {
                            return Err(Error::BadVariableName(None));
                        }
                        Ok((Some((data, TK_VARIABLE)), data.len()))
                    }
                }
            }
            b if is_identifier_start(b) => {
                if b == b'x' || b == b'X' {
                    if let Some(&b'\'') = data.get(1) {
                        blob_literal(data)
                    } else {
                        Ok(self.identifierish(data))
                    }
                } else {
                    Ok(self.identifierish(data))
                }
            }
            _ => Err(Error::UnrecognizedToken(None)),
        }
    }
}

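// e.g. 'it''s' is a single TK_STRING token (a doubled quote is an escape),
// while `col` and "col" are quoted identifiers (TK_ID).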
fn literal(data: &[u8], quote: u8) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], quote);
    let tt = if quote == b'\'' { TK_STRING } else { TK_ID };
    let mut pb = 0;
    let mut end = None;
    // data[0] == quote => skip(1)
    for (i, b) in data.iter().enumerate().skip(1) {
        if *b == quote {
            if pb == quote {
                // escaped quote
                pb = 0;
                continue;
            }
        } else if pb == quote {
            end = Some(i);
            break;
        }
        pb = *b;
    }
    if end.is_some() || pb == quote {
        let i = match end {
            Some(i) => i,
            _ => data.len(),
        };
        // keep original quotes in the token
        Ok((Some((&data[0..i], tt)), i))
    } else {
        Err(Error::UnterminatedLiteral(None))
    }
}

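// e.g. X'53514C' lexes as TK_BLOB; the token excludes the X and the quotes.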
fn blob_literal(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0] == b'x' || data[0] == b'X');
    debug_assert_eq!(data[1], b'\'');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &b)| !b.is_ascii_hexdigit())
    {
        if *b != b'\'' || i % 2 != 0 {
            return Err(Error::MalformedBlobLiteral(None));
        }
        Ok((Some((&data[2..i], TK_BLOB)), i + 1))
    } else {
        Err(Error::MalformedBlobLiteral(None))
    }
}

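// Recognizes integers (`42`, `0x1A`) and floats (`1.5`, `1e-3`), including `_`
// digit separators (validated in `find_end_of_number` below); a leading `.5`
// is handled by the `b'.'` arm of `split` above.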
fn number(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0].is_ascii_digit());
    if data[0] == b'0' {
        if let Some(b) = data.get(1) {
            if *b == b'x' || *b == b'X' {
                return hex_integer(data);
            }
        } else {
            return Ok((Some((data, TK_INTEGER)), data.len()));
        }
    }
    if let Some((i, b)) = find_end_of_number(data, 1, u8::is_ascii_digit)? {
        if b == b'.' {
            return fractional_part(data, i);
        } else if b == b'e' || b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}

fn hex_integer(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], b'0');
    debug_assert!(data[1] == b'x' || data[1] == b'X');
    if let Some((i, b)) = find_end_of_number(data, 2, u8::is_ascii_hexdigit)? {
        // Must not be empty (0x is invalid)
        if i == 2 || is_identifier_start(b) {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        // Must not be empty (0x is invalid)
        if data.len() == 2 {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}

fn fractional_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[i], b'.');
    if let Some((i, b)) = find_end_of_number(data, i + 1, u8::is_ascii_digit)? {
        if b == b'e' || b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_FLOAT)), i))
    } else {
        Ok((Some((data, TK_FLOAT)), data.len()))
    }
}

fn exponential_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[i] == b'e' || data[i] == b'E');
    // data[i] == 'e'|'E'
    if let Some(b) = data.get(i + 1) {
        let i = if *b == b'+' || *b == b'-' { i + 1 } else { i };
        if let Some((j, b)) = find_end_of_number(data, i + 1, u8::is_ascii_digit)? {
            if j == i + 1 || is_identifier_start(b) {
                return Err(Error::BadNumber(None));
            }
            Ok((Some((&data[..j], TK_FLOAT)), j))
        } else if data.len() == i + 1 {
            Err(Error::BadNumber(None))
        } else {
            Ok((Some((data, TK_FLOAT)), data.len()))
        }
    } else {
        Err(Error::BadNumber(None))
    }
}

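// An `_` is only valid between two characters of the same digit class:
// `1_000` is accepted, while `_1`, `1_` and `1__0` are rejected.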
fn find_end_of_number(
    data: &[u8],
    i: usize,
    test: fn(&u8) -> bool,
) -> Result<Option<(usize, u8)>, Error> {
    for (j, &b) in data.iter().enumerate().skip(i) {
        if test(&b) {
            continue;
        } else if b == b'_' {
            if j >= 1 && data.get(j - 1).is_some_and(test) && data.get(j + 1).is_some_and(test) {
                continue;
            }
            return Err(Error::BadNumber(None));
        } else {
            return Ok(Some((j, b)));
        }
    }
    Ok(None)
}

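// `identifierish` takes the longest identifier-like run; only ASCII words of
// keyword-compatible length are looked up in the keyword table.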
impl Tokenizer {
    fn identifierish<'input>(&mut self, data: &'input [u8]) -> (Option<Token<'input>>, usize) {
        debug_assert!(is_identifier_start(data[0]));
        // data[0] is_identifier_start => skip(1)
        let end = data
            .iter()
            .skip(1)
            .position(|&b| !is_identifier_continue(b));
        let i = match end {
            Some(i) => i + 1,
            _ => data.len(),
        };
        let word = &data[..i];
        let tt = if word.len() >= 2 && word.len() <= MAX_KEYWORD_LEN && word.is_ascii() {
            keyword_token(word).unwrap_or(TK_ID)
        } else {
            TK_ID
        };
        (Some((word, tt)), i)
    }
}

#[cfg(test)]
mod tests {
    use super::Tokenizer;
    use crate::dialect::TokenType;
    use crate::lexer::sql::Error;
    use crate::lexer::Scanner;

    #[test]
    fn fallible_iterator() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"PRAGMA parser_trace=ON;";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"PRAGMA", TokenType::TK_PRAGMA)?;
        expect_token(&mut s, input, b"parser_trace", TokenType::TK_ID)?;
        Ok(())
    }

    #[test]
    fn invalid_number_literal() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"SELECT 1E;";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"SELECT", TokenType::TK_SELECT)?;
        let err = s.scan(input).unwrap_err();
        assert!(matches!(err, Error::BadNumber(_)));
        Ok(())
    }

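    // Coverage sketches of the lexing contracts above: blob tokens are
    // returned without the X'..' wrapper, and `_` digit separators are
    // accepted between digits (see `blob_literal` and `find_end_of_number`).
    #[test]
    fn blob_literal_token() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"X'53514C'";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"53514C", TokenType::TK_BLOB)?;
        Ok(())
    }

    #[test]
    fn numeric_underscores() -> Result<(), Error> {
        let tokenizer = Tokenizer::new();
        let input = b"SELECT 1_000;";
        let mut s = Scanner::new(tokenizer);
        expect_token(&mut s, input, b"SELECT", TokenType::TK_SELECT)?;
        expect_token(&mut s, input, b"1_000", TokenType::TK_INTEGER)?;
        Ok(())
    }
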
    fn expect_token(
        s: &mut Scanner<Tokenizer>,
        input: &[u8],
        token: &[u8],
        token_type: TokenType,
    ) -> Result<(), Error> {
        let (t, tt) = s.scan(input)?.1.unwrap();
        assert_eq!(token, t);
        assert_eq!(token_type, tt);
        Ok(())
    }
}