sochdb_query/sql/
lexer.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! SQL Lexer
16//!
17//! Converts SQL text into a stream of tokens.
18//! Handles string literals, numbers, identifiers, keywords, and operators.
19
20use super::token::{Span, Token, TokenKind};
21use std::iter::Peekable;
22use std::str::Chars;
23
/// SQL Lexer errors
#[derive(Debug, Clone, PartialEq)]
pub struct LexError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Source location of the error (byte range plus line/column).
    pub span: Span,
}
30
31impl LexError {
32    pub fn new(message: impl Into<String>, span: Span) -> Self {
33        Self {
34            message: message.into(),
35            span,
36        }
37    }
38}
39
40impl std::fmt::Display for LexError {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(
43            f,
44            "Lexer error at line {}, column {}: {}",
45            self.span.line, self.span.column, self.message
46        )
47    }
48}
49
50impl std::error::Error for LexError {}
51
/// SQL Lexer - tokenizes SQL input
pub struct Lexer<'a> {
    /// The full input text; sliced by byte offsets to recover token literals.
    input: &'a str,
    /// Character cursor over `input`, with one character of lookahead.
    chars: Peekable<Chars<'a>>,
    /// Current byte offset into `input` (advanced by each char's UTF-8 length).
    pos: usize,
    /// Current 1-based line number (incremented on '\n').
    line: usize,
    /// Current 1-based column number (reset to 1 after '\n').
    column: usize,
    /// Tokens produced so far.
    tokens: Vec<Token>,
    /// Errors collected so far; lexing continues after an error.
    errors: Vec<LexError>,
    /// Counter for `?` style placeholders (auto-incrementing)
    placeholder_counter: u32,
}
64
65impl<'a> Lexer<'a> {
66    /// Create a new lexer for the given SQL input
67    pub fn new(input: &'a str) -> Self {
68        Self {
69            input,
70            chars: input.chars().peekable(),
71            pos: 0,
72            line: 1,
73            column: 1,
74            tokens: Vec::new(),
75            errors: Vec::new(),
76            placeholder_counter: 0,
77        }
78    }
79
80    /// Tokenize the entire input
81    pub fn tokenize(mut self) -> Result<Vec<Token>, Vec<LexError>> {
82        while !self.is_at_end() {
83            self.scan_token();
84        }
85
86        // Add EOF token
87        self.tokens.push(Token::new(
88            TokenKind::Eof,
89            Span::new(self.pos, self.pos, self.line, self.column),
90            "",
91        ));
92
93        if self.errors.is_empty() {
94            Ok(self.tokens)
95        } else {
96            Err(self.errors)
97        }
98    }
99
100    fn is_at_end(&mut self) -> bool {
101        self.chars.peek().is_none()
102    }
103
104    fn advance(&mut self) -> Option<char> {
105        let c = self.chars.next()?;
106        self.pos += c.len_utf8();
107        if c == '\n' {
108            self.line += 1;
109            self.column = 1;
110        } else {
111            self.column += 1;
112        }
113        Some(c)
114    }
115
116    fn peek(&mut self) -> Option<char> {
117        self.chars.peek().copied()
118    }
119
120    fn peek_next(&self) -> Option<char> {
121        let mut chars = self.chars.clone();
122        chars.next();
123        chars.next()
124    }
125
126    fn make_span(&self, start: usize, start_line: usize, start_col: usize) -> Span {
127        Span::new(start, self.pos, start_line, start_col)
128    }
129
130    fn scan_token(&mut self) {
131        let start = self.pos;
132        let start_line = self.line;
133        let start_col = self.column;
134
135        let c = match self.advance() {
136            Some(c) => c,
137            None => return,
138        };
139
140        match c {
141            // Whitespace
142            ' ' | '\t' | '\r' | '\n' => {
143                // Skip whitespace, don't emit token
144            }
145
146            // Single-character tokens
147            '(' => self.add_token(TokenKind::LParen, start, start_line, start_col),
148            ')' => self.add_token(TokenKind::RParen, start, start_line, start_col),
149            '[' => self.add_token(TokenKind::LBracket, start, start_line, start_col),
150            ']' => self.add_token(TokenKind::RBracket, start, start_line, start_col),
151            ',' => self.add_token(TokenKind::Comma, start, start_line, start_col),
152            ';' => self.add_token(TokenKind::Semicolon, start, start_line, start_col),
153            '+' => self.add_token(TokenKind::Plus, start, start_line, start_col),
154            '*' => self.add_token(TokenKind::Star, start, start_line, start_col),
155            '/' => {
156                if self.peek() == Some('/') || self.peek() == Some('*') {
157                    self.scan_comment(start, start_line, start_col);
158                } else {
159                    self.add_token(TokenKind::Slash, start, start_line, start_col);
160                }
161            }
162            '%' => self.add_token(TokenKind::Percent, start, start_line, start_col),
163            '&' => self.add_token(TokenKind::BitAnd, start, start_line, start_col),
164            '~' => self.add_token(TokenKind::BitNot, start, start_line, start_col),
165            '?' => {
166                // Auto-incrementing placeholder for JDBC/ODBC style ?
167                self.placeholder_counter += 1;
168                let span = self.make_span(start, start_line, start_col);
169                self.tokens.push(Token::new(
170                    TokenKind::Placeholder(self.placeholder_counter),
171                    span,
172                    "?",
173                ));
174            }
175            '@' => self.add_token(TokenKind::At, start, start_line, start_col),
176
177            // Two-character tokens
178            '-' => {
179                if self.peek() == Some('-') {
180                    // Line comment
181                    self.scan_line_comment(start, start_line, start_col);
182                } else if self.peek() == Some('>') {
183                    self.advance();
184                    if self.peek() == Some('>') {
185                        self.advance();
186                        self.add_token(TokenKind::DoubleArrow, start, start_line, start_col);
187                    } else {
188                        self.add_token(TokenKind::Arrow, start, start_line, start_col);
189                    }
190                } else {
191                    self.add_token(TokenKind::Minus, start, start_line, start_col);
192                }
193            }
194
195            '=' => self.add_token(TokenKind::Eq, start, start_line, start_col),
196
197            '!' => {
198                if self.peek() == Some('=') {
199                    self.advance();
200                    self.add_token(TokenKind::Ne, start, start_line, start_col);
201                } else {
202                    self.add_error("Unexpected character '!'", start, start_line, start_col);
203                }
204            }
205
206            '<' => {
207                if self.peek() == Some('=') {
208                    self.advance();
209                    self.add_token(TokenKind::Le, start, start_line, start_col);
210                } else if self.peek() == Some('>') {
211                    self.advance();
212                    self.add_token(TokenKind::Ne, start, start_line, start_col);
213                } else if self.peek() == Some('<') {
214                    self.advance();
215                    self.add_token(TokenKind::LeftShift, start, start_line, start_col);
216                } else {
217                    self.add_token(TokenKind::Lt, start, start_line, start_col);
218                }
219            }
220
221            '>' => {
222                if self.peek() == Some('=') {
223                    self.advance();
224                    self.add_token(TokenKind::Ge, start, start_line, start_col);
225                } else if self.peek() == Some('>') {
226                    self.advance();
227                    self.add_token(TokenKind::RightShift, start, start_line, start_col);
228                } else {
229                    self.add_token(TokenKind::Gt, start, start_line, start_col);
230                }
231            }
232
233            '|' => {
234                if self.peek() == Some('|') {
235                    self.advance();
236                    self.add_token(TokenKind::Concat, start, start_line, start_col);
237                } else {
238                    self.add_token(TokenKind::BitOr, start, start_line, start_col);
239                }
240            }
241
242            ':' => {
243                if self.peek() == Some(':') {
244                    self.advance();
245                    self.add_token(TokenKind::DoubleColon, start, start_line, start_col);
246                } else {
247                    self.add_token(TokenKind::Colon, start, start_line, start_col);
248                }
249            }
250
251            '.' => {
252                if self.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
253                    self.scan_number(start, start_line, start_col, true);
254                } else {
255                    self.add_token(TokenKind::Dot, start, start_line, start_col);
256                }
257            }
258
259            // String literals
260            '\'' => self.scan_string(start, start_line, start_col, '\''),
261            '"' => self.scan_quoted_identifier(start, start_line, start_col, '"'),
262            '`' => self.scan_quoted_identifier(start, start_line, start_col, '`'),
263
264            // Blob literal (X'...')
265            'X' | 'x' if self.peek() == Some('\'') => {
266                self.advance();
267                self.scan_blob(start, start_line, start_col);
268            }
269
270            // Numbers
271            '0'..='9' => self.scan_number(start, start_line, start_col, false),
272
273            // Identifiers and keywords
274            'a'..='z' | 'A'..='Z' | '_' => self.scan_identifier(start, start_line, start_col),
275
276            // Placeholder ($1, $2, ...)
277            '$' => self.scan_placeholder(start, start_line, start_col),
278
279            _ => {
280                self.add_error(
281                    format!("Unexpected character '{}'", c),
282                    start,
283                    start_line,
284                    start_col,
285                );
286            }
287        }
288    }
289
290    fn scan_string(&mut self, start: usize, start_line: usize, start_col: usize, quote: char) {
291        let mut value = String::new();
292
293        while let Some(c) = self.peek() {
294            if c == quote {
295                self.advance();
296                // Check for escaped quote ('')
297                if self.peek() == Some(quote) {
298                    self.advance();
299                    value.push(quote);
300                } else {
301                    // End of string
302                    let span = self.make_span(start, start_line, start_col);
303                    self.tokens
304                        .push(Token::new(TokenKind::String(value), span, ""));
305                    return;
306                }
307            } else if c == '\\' {
308                self.advance();
309                // Handle escape sequences
310                if let Some(escaped) = self.advance() {
311                    match escaped {
312                        'n' => value.push('\n'),
313                        'r' => value.push('\r'),
314                        't' => value.push('\t'),
315                        '\\' => value.push('\\'),
316                        '\'' => value.push('\''),
317                        '"' => value.push('"'),
318                        '0' => value.push('\0'),
319                        _ => {
320                            value.push('\\');
321                            value.push(escaped);
322                        }
323                    }
324                }
325            } else {
326                self.advance();
327                value.push(c);
328            }
329        }
330
331        self.add_error("Unterminated string literal", start, start_line, start_col);
332    }
333
334    fn scan_quoted_identifier(
335        &mut self,
336        start: usize,
337        start_line: usize,
338        start_col: usize,
339        quote: char,
340    ) {
341        let mut value = String::new();
342
343        while let Some(c) = self.peek() {
344            if c == quote {
345                self.advance();
346                // Check for escaped quote
347                if self.peek() == Some(quote) {
348                    self.advance();
349                    value.push(quote);
350                } else {
351                    let span = self.make_span(start, start_line, start_col);
352                    self.tokens
353                        .push(Token::new(TokenKind::QuotedIdentifier(value), span, ""));
354                    return;
355                }
356            } else {
357                self.advance();
358                value.push(c);
359            }
360        }
361
362        self.add_error(
363            "Unterminated quoted identifier",
364            start,
365            start_line,
366            start_col,
367        );
368    }
369
370    fn scan_number(
371        &mut self,
372        start: usize,
373        start_line: usize,
374        start_col: usize,
375        started_with_dot: bool,
376    ) {
377        let num_start = start;
378        let mut has_dot = started_with_dot;
379        let mut has_exp = false;
380
381        // Consume integer part
382        while let Some(c) = self.peek() {
383            if c.is_ascii_digit() {
384                self.advance();
385            } else if c == '.' && !has_dot && !has_exp {
386                // Check it's not a range operator (..)
387                if self.peek_next() == Some('.') {
388                    break;
389                }
390                has_dot = true;
391                self.advance();
392            } else if (c == 'e' || c == 'E') && !has_exp {
393                has_exp = true;
394                self.advance();
395                // Optional sign
396                if self.peek() == Some('+') || self.peek() == Some('-') {
397                    self.advance();
398                }
399            } else {
400                break;
401            }
402        }
403
404        let literal = &self.input[num_start..self.pos];
405        let span = self.make_span(start, start_line, start_col);
406
407        if has_dot || has_exp {
408            match literal.parse::<f64>() {
409                Ok(n) => self
410                    .tokens
411                    .push(Token::new(TokenKind::Float(n), span, literal)),
412                Err(_) => self.add_error("Invalid float literal", start, start_line, start_col),
413            }
414        } else {
415            match literal.parse::<i64>() {
416                Ok(n) => self
417                    .tokens
418                    .push(Token::new(TokenKind::Integer(n), span, literal)),
419                Err(_) => self.add_error("Invalid integer literal", start, start_line, start_col),
420            }
421        }
422    }
423
424    fn scan_identifier(&mut self, start: usize, start_line: usize, start_col: usize) {
425        while let Some(c) = self.peek() {
426            if c.is_ascii_alphanumeric() || c == '_' {
427                self.advance();
428            } else {
429                break;
430            }
431        }
432
433        let literal = &self.input[start..self.pos];
434        let span = self.make_span(start, start_line, start_col);
435
436        // Check for keyword
437        let kind = TokenKind::from_keyword(literal)
438            .unwrap_or_else(|| TokenKind::Identifier(literal.to_string()));
439
440        self.tokens.push(Token::new(kind, span, literal));
441    }
442
443    fn scan_placeholder(&mut self, start: usize, start_line: usize, start_col: usize) {
444        let mut num = String::new();
445
446        while let Some(c) = self.peek() {
447            if c.is_ascii_digit() {
448                self.advance();
449                num.push(c);
450            } else {
451                break;
452            }
453        }
454
455        let span = self.make_span(start, start_line, start_col);
456
457        if num.is_empty() {
458            self.add_error("Expected number after $", start, start_line, start_col);
459        } else if let Ok(n) = num.parse::<u32>() {
460            self.tokens.push(Token::new(
461                TokenKind::Placeholder(n),
462                span,
463                &self.input[start..self.pos],
464            ));
465        } else {
466            self.add_error("Invalid placeholder number", start, start_line, start_col);
467        }
468    }
469
470    fn scan_comment(&mut self, start: usize, start_line: usize, start_col: usize) {
471        self.advance(); // consume second / or *
472
473        if self.peek() == Some('*') || self.input[start..self.pos].ends_with('*') {
474            // Block comment /* ... */
475            let mut depth = 1;
476
477            while depth > 0 && !self.is_at_end() {
478                let c = self.peek();
479                let next = self.peek_next();
480
481                if c == Some('*') && next == Some('/') {
482                    self.advance();
483                    self.advance();
484                    depth -= 1;
485                } else if c == Some('/') && next == Some('*') {
486                    self.advance();
487                    self.advance();
488                    depth += 1;
489                } else {
490                    self.advance();
491                }
492            }
493
494            if depth > 0 {
495                self.add_error("Unterminated block comment", start, start_line, start_col);
496            }
497        } else {
498            // Line comment //
499            while let Some(c) = self.peek() {
500                if c == '\n' {
501                    break;
502                }
503                self.advance();
504            }
505        }
506        // Don't emit comment tokens
507    }
508
509    fn scan_line_comment(&mut self, _start: usize, _start_line: usize, _start_col: usize) {
510        self.advance(); // consume second -
511
512        while let Some(c) = self.peek() {
513            if c == '\n' {
514                break;
515            }
516            self.advance();
517        }
518        // Don't emit comment tokens
519    }
520
521    fn scan_blob(&mut self, start: usize, start_line: usize, start_col: usize) {
522        let mut hex = String::new();
523
524        while let Some(c) = self.peek() {
525            if c == '\'' {
526                self.advance();
527                break;
528            } else if c.is_ascii_hexdigit() {
529                self.advance();
530                hex.push(c);
531            } else if c.is_whitespace() {
532                self.advance(); // Allow whitespace in blob
533            } else {
534                self.add_error(
535                    "Invalid hex digit in blob literal",
536                    start,
537                    start_line,
538                    start_col,
539                );
540                return;
541            }
542        }
543
544        if !hex.len().is_multiple_of(2) {
545            self.add_error(
546                "Blob literal must have even number of hex digits",
547                start,
548                start_line,
549                start_col,
550            );
551            return;
552        }
553
554        let bytes: Result<Vec<u8>, _> = (0..hex.len())
555            .step_by(2)
556            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16))
557            .collect();
558
559        match bytes {
560            Ok(data) => {
561                let span = self.make_span(start, start_line, start_col);
562                self.tokens
563                    .push(Token::new(TokenKind::Blob(data), span, ""));
564            }
565            Err(_) => {
566                self.add_error("Invalid blob literal", start, start_line, start_col);
567            }
568        }
569    }
570
571    fn add_token(&mut self, kind: TokenKind, start: usize, start_line: usize, start_col: usize) {
572        let span = self.make_span(start, start_line, start_col);
573        let literal = &self.input[start..self.pos];
574        self.tokens.push(Token::new(kind, span, literal));
575    }
576
577    fn add_error(
578        &mut self,
579        message: impl Into<String>,
580        start: usize,
581        start_line: usize,
582        start_col: usize,
583    ) {
584        let span = self.make_span(start, start_line, start_col);
585        self.errors.push(LexError::new(message, span));
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenize `input`, panicking if the lexer reports any errors.
    fn lex(input: &str) -> Vec<Token> {
        Lexer::new(input).tokenize().unwrap()
    }

    #[test]
    fn test_simple_select() {
        let tokens = lex("SELECT * FROM users");
        // SELECT, *, FROM, users, EOF
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
        assert_eq!(tokens[2].kind, TokenKind::From);
        assert!(matches!(tokens[3].kind, TokenKind::Identifier(_)));
    }

    #[test]
    fn test_string_literal() {
        // A doubled quote inside a string is an escaped single quote.
        let tokens = lex("SELECT 'hello''world'");
        assert!(matches!(&tokens[1].kind, TokenKind::String(s) if s == "hello'world"));
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn test_numbers() {
        let tokens = lex("42 3.14 1e10 .5");
        assert!(matches!(tokens[0].kind, TokenKind::Integer(42)));
        assert!(matches!(tokens[1].kind, TokenKind::Float(f) if (f - 3.14).abs() < 0.001));
        assert!(matches!(tokens[2].kind, TokenKind::Float(_)));
        assert!(matches!(tokens[3].kind, TokenKind::Float(f) if (f - 0.5).abs() < 0.001));
    }

    #[test]
    fn test_operators() {
        let tokens = lex("= != <> < <= > >= || ->");
        let expected = [
            TokenKind::Eq,
            TokenKind::Ne,
            TokenKind::Ne,
            TokenKind::Lt,
            TokenKind::Le,
            TokenKind::Gt,
            TokenKind::Ge,
            TokenKind::Concat,
            TokenKind::Arrow,
        ];
        for (token, kind) in tokens.iter().zip(expected) {
            assert_eq!(token.kind, kind);
        }
    }

    #[test]
    fn test_keywords() {
        let tokens = lex("SELECT INSERT UPDATE DELETE FROM WHERE");
        let expected = [
            TokenKind::Select,
            TokenKind::Insert,
            TokenKind::Update,
            TokenKind::Delete,
            TokenKind::From,
            TokenKind::Where,
        ];
        for (token, kind) in tokens.iter().zip(expected) {
            assert_eq!(token.kind, kind);
        }
    }

    #[test]
    fn test_placeholder() {
        let tokens = lex("$1 $2 $10");
        assert!(matches!(tokens[0].kind, TokenKind::Placeholder(1)));
        assert!(matches!(tokens[1].kind, TokenKind::Placeholder(2)));
        assert!(matches!(tokens[2].kind, TokenKind::Placeholder(10)));
    }

    #[test]
    fn test_line_comment() {
        // The comment is skipped entirely and emits no token.
        let tokens = lex("SELECT -- comment\n* FROM users");
        // SELECT, *, FROM, users, EOF
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
    }

    #[test]
    fn test_blob_literal() {
        let tokens = lex("X'48454C4C4F'");
        assert!(matches!(&tokens[0].kind, TokenKind::Blob(b) if b == b"HELLO"));
    }
}