rustlite_core/query/lexer.rs

//! Lexer for tokenizing SQL-like queries
//!
//! Converts raw SQL text into a stream of tokens for parsing.

use std::fmt;

/// Token types produced by the lexer
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords
    Select,
    From,
    Where,
    Group,
    By,
    Having,
    OrderBy,
    Limit,
    Offset,
    Join,
    Inner,
    Left,
    Right,
    Full,
    On,
    As,
    And,
    Or,
    Not,
    Like,
    In,
    Between,

    // Aggregate functions
    Count,
    Sum,
    Avg,
    Min,
    Max,

    // Operators
    Eq, // =
    Ne, // !=
    Lt, // <
    Le, // <=
    Gt, // >
    Ge, // >=

    // Literals
    Integer(i64),
    Float(f64),
    String(String),
    Boolean(bool),
    Null,

    // Identifiers
    Identifier(String),

    // Punctuation
    Asterisk,   // *
    Comma,      // ,
    LeftParen,  // (
    RightParen, // )

    // Sort directions
    Asc,
    Desc,

    // End of input
    Eof,
}

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Token::Select => write!(f, "SELECT"),
            Token::From => write!(f, "FROM"),
            Token::Where => write!(f, "WHERE"),
            Token::Group => write!(f, "GROUP"),
            Token::By => write!(f, "BY"),
            Token::Having => write!(f, "HAVING"),
            Token::OrderBy => write!(f, "ORDER BY"),
            Token::Limit => write!(f, "LIMIT"),
            Token::Offset => write!(f, "OFFSET"),
            Token::Join => write!(f, "JOIN"),
            Token::Inner => write!(f, "INNER"),
            Token::Left => write!(f, "LEFT"),
            Token::Right => write!(f, "RIGHT"),
            Token::Full => write!(f, "FULL"),
            Token::On => write!(f, "ON"),
            Token::As => write!(f, "AS"),
            Token::And => write!(f, "AND"),
            Token::Or => write!(f, "OR"),
            Token::Not => write!(f, "NOT"),
            Token::Like => write!(f, "LIKE"),
            Token::In => write!(f, "IN"),
            Token::Between => write!(f, "BETWEEN"),
            Token::Count => write!(f, "COUNT"),
            Token::Sum => write!(f, "SUM"),
            Token::Avg => write!(f, "AVG"),
            Token::Min => write!(f, "MIN"),
            Token::Max => write!(f, "MAX"),
            Token::Eq => write!(f, "="),
            Token::Ne => write!(f, "!="),
            Token::Lt => write!(f, "<"),
            Token::Le => write!(f, "<="),
            Token::Gt => write!(f, ">"),
            Token::Ge => write!(f, ">="),
            Token::Integer(i) => write!(f, "{}", i),
            Token::Float(fl) => write!(f, "{}", fl),
            Token::String(s) => write!(f, "'{}'", s),
            Token::Boolean(b) => write!(f, "{}", b),
            Token::Null => write!(f, "NULL"),
            Token::Identifier(id) => write!(f, "{}", id),
            Token::Asterisk => write!(f, "*"),
            Token::Comma => write!(f, ","),
            Token::LeftParen => write!(f, "("),
            Token::RightParen => write!(f, ")"),
            Token::Asc => write!(f, "ASC"),
            Token::Desc => write!(f, "DESC"),
            Token::Eof => write!(f, "EOF"),
        }
    }
}

/// Lexer state
pub struct Lexer {
    input: Vec<char>,
    position: usize,
}

impl Lexer {
    /// Create a new lexer from input string
    pub fn new(input: &str) -> Self {
        Self {
            input: input.chars().collect(),
            position: 0,
        }
    }

    /// Get the next token
    pub fn next_token(&mut self) -> Result<Token, LexerError> {
        self.skip_whitespace();

        if self.position >= self.input.len() {
            return Ok(Token::Eof);
        }

        let ch = self.current_char();

        // Punctuation, operators (including two-character forms), and strings
        match ch {
            '*' => {
                self.advance();
                return Ok(Token::Asterisk);
            }
            ',' => {
                self.advance();
                return Ok(Token::Comma);
            }
            '(' => {
                self.advance();
                return Ok(Token::LeftParen);
            }
            ')' => {
                self.advance();
                return Ok(Token::RightParen);
            }
            '=' => {
                self.advance();
                return Ok(Token::Eq);
            }
            '<' => {
                self.advance();
                if self.position < self.input.len() && self.current_char() == '=' {
                    self.advance();
                    return Ok(Token::Le);
                }
                return Ok(Token::Lt);
            }
            '>' => {
                self.advance();
                if self.position < self.input.len() && self.current_char() == '=' {
                    self.advance();
                    return Ok(Token::Ge);
                }
                return Ok(Token::Gt);
            }
            '!' => {
                self.advance();
                if self.position < self.input.len() && self.current_char() == '=' {
                    self.advance();
                    return Ok(Token::Ne);
                }
                // A bare `!` is not a valid token on its own
                return Err(LexerError::UnexpectedCharacter(ch));
            }
            '\'' => return self.read_string(),
            _ => {}
        }

        // Numbers
        if ch.is_ascii_digit() {
            return self.read_number();
        }

        // Identifiers and keywords
        if ch.is_alphabetic() || ch == '_' {
            return self.read_identifier_or_keyword();
        }

        Err(LexerError::UnexpectedCharacter(ch))
    }

    /// Tokenize entire input into vector of tokens
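    ///
    /// A minimal usage sketch; the `rustlite_core::query::lexer` path is
    /// inferred from the file location and may need adjusting to the actual
    /// module layout, so the example is marked `ignore`:
    ///
    /// ```ignore
    /// use rustlite_core::query::lexer::{Lexer, Token};
    ///
    /// let mut lexer = Lexer::new("SELECT * FROM users WHERE age >= 21");
    /// let tokens = lexer.tokenize().unwrap();
    /// assert_eq!(tokens.first(), Some(&Token::Select));
    /// assert_eq!(tokens.last(), Some(&Token::Eof));
    /// ```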
    pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token()?;
            let done = token == Token::Eof;
            tokens.push(token);
            if done {
                break;
            }
        }
        Ok(tokens)
    }

    /// Current character. Panics if `position` is past the end of input,
    /// so callers must bounds-check first.
    fn current_char(&self) -> char {
        self.input[self.position]
    }

    /// Character after the current one, if any.
    fn peek_char(&self) -> Option<char> {
        if self.position + 1 < self.input.len() {
            Some(self.input[self.position + 1])
        } else {
            None
        }
    }

    /// Move past the current character.
    fn advance(&mut self) {
        self.position += 1;
    }

    fn skip_whitespace(&mut self) {
        while self.position < self.input.len() && self.current_char().is_whitespace() {
            self.advance();
        }
    }

    fn read_number(&mut self) -> Result<Token, LexerError> {
        let start = self.position;
        let mut has_dot = false;

        while self.position < self.input.len() {
            let ch = self.current_char();
            if ch.is_ascii_digit() {
                self.advance();
            } else if ch == '.' && !has_dot && self.peek_char().is_some_and(|c| c.is_ascii_digit())
            {
                has_dot = true;
                self.advance();
            } else {
                break;
            }
        }

        let num_str: String = self.input[start..self.position].iter().collect();

        if has_dot {
            num_str
                .parse::<f64>()
                .map(Token::Float)
                .map_err(|_| LexerError::InvalidNumber(num_str))
        } else {
            num_str
                .parse::<i64>()
                .map(Token::Integer)
                .map_err(|_| LexerError::InvalidNumber(num_str))
        }
    }

    fn read_string(&mut self) -> Result<Token, LexerError> {
        self.advance(); // skip opening quote
        let start = self.position;

        // Note: escape sequences (e.g. a doubled '' as in SQL) are not
        // supported; the string ends at the first closing quote.
        while self.position < self.input.len() && self.current_char() != '\'' {
            self.advance();
        }

        if self.position >= self.input.len() {
            return Err(LexerError::UnterminatedString);
        }

        let string: String = self.input[start..self.position].iter().collect();
        self.advance(); // skip closing quote

        Ok(Token::String(string))
    }

    fn read_identifier_or_keyword(&mut self) -> Result<Token, LexerError> {
        let start = self.position;

        while self.position < self.input.len() {
            let ch = self.current_char();
            if ch.is_alphanumeric() || ch == '_' || ch == '.' {
                self.advance();
            } else {
                break;
            }
        }

        let text: String = self.input[start..self.position].iter().collect();
        let uppercase = text.to_uppercase();

        // Check for the multi-word keyword ORDER BY. The lookahead reads the
        // next word with the same character rules as identifiers, so words
        // that merely start with BY (e.g. `BY_col` or `BY2`) are not
        // mistaken for it.
        if uppercase == "ORDER" {
            self.skip_whitespace();
            if self.position < self.input.len() {
                let next_start = self.position;
                let mut next_text = String::new();
                while self.position < self.input.len() {
                    let ch = self.current_char();
                    if ch.is_alphanumeric() || ch == '_' || ch == '.' {
                        next_text.push(ch);
                        self.advance();
                    } else {
                        break;
                    }
                }
                if next_text.to_uppercase() == "BY" {
                    return Ok(Token::OrderBy);
                }
                // Not followed by BY: roll back and emit ORDER as a plain
                // identifier below.
                self.position = next_start;
            }
        }

        // Match keywords against the uppercased text
        let token = match uppercase.as_str() {
            "SELECT" => Token::Select,
            "FROM" => Token::From,
            "WHERE" => Token::Where,
            "GROUP" => Token::Group,
            "BY" => Token::By,
            "HAVING" => Token::Having,
            "LIMIT" => Token::Limit,
            "OFFSET" => Token::Offset,
            "JOIN" => Token::Join,
            "INNER" => Token::Inner,
            "LEFT" => Token::Left,
            "RIGHT" => Token::Right,
            "FULL" => Token::Full,
            "ON" => Token::On,
            "AS" => Token::As,
            "AND" => Token::And,
            "OR" => Token::Or,
            "NOT" => Token::Not,
            "LIKE" => Token::Like,
            "IN" => Token::In,
            "BETWEEN" => Token::Between,
            "COUNT" => Token::Count,
            "SUM" => Token::Sum,
            "AVG" => Token::Avg,
            "MIN" => Token::Min,
            "MAX" => Token::Max,
            "ASC" => Token::Asc,
            "DESC" => Token::Desc,
            "TRUE" => Token::Boolean(true),
            "FALSE" => Token::Boolean(false),
            "NULL" => Token::Null,
            _ => Token::Identifier(text),
        };

        Ok(token)
    }
}

/// Lexer errors
#[derive(Debug, Clone, PartialEq)]
pub enum LexerError {
    UnexpectedCharacter(char),
    InvalidNumber(String),
    UnterminatedString,
}

impl fmt::Display for LexerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LexerError::UnexpectedCharacter(ch) => write!(f, "Unexpected character: '{}'", ch),
            LexerError::InvalidNumber(s) => write!(f, "Invalid number: '{}'", s),
            LexerError::UnterminatedString => write!(f, "Unterminated string literal"),
        }
    }
}

impl std::error::Error for LexerError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_select() {
        let mut lexer = Lexer::new("SELECT * FROM users");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![
                Token::Select,
                Token::Asterisk,
                Token::From,
                Token::Identifier("users".to_string()),
                Token::Eof,
            ]
        );
    }
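
    // The lexer uppercases each word before keyword matching, so keywords
    // are case-insensitive; a quick check of that behavior.
    #[test]
    fn test_keywords_case_insensitive() {
        let mut lexer = Lexer::new("select * from users");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(tokens[0], Token::Select);
        assert_eq!(tokens[2], Token::From);
    }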

    #[test]
    fn test_select_with_where() {
        let mut lexer = Lexer::new("SELECT name FROM users WHERE age > 18");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(tokens[0], Token::Select);
        assert_eq!(tokens[1], Token::Identifier("name".to_string()));
        assert_eq!(tokens[2], Token::From);
        assert_eq!(tokens[3], Token::Identifier("users".to_string()));
        assert_eq!(tokens[4], Token::Where);
        assert_eq!(tokens[5], Token::Identifier("age".to_string()));
        assert_eq!(tokens[6], Token::Gt);
        assert_eq!(tokens[7], Token::Integer(18));
    }
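
    // `.` is an identifier character, so a qualified column reference like
    // `users.name` lexes as a single Identifier token.
    #[test]
    fn test_qualified_identifier() {
        let mut lexer = Lexer::new("SELECT users.name FROM users");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(tokens[1], Token::Identifier("users.name".to_string()));
    }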

    #[test]
    fn test_string_literals() {
        let mut lexer = Lexer::new("SELECT * FROM users WHERE name = 'John'");
        let tokens = lexer.tokenize().unwrap();

        assert!(tokens.contains(&Token::String("John".to_string())));
    }
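
    // A string with no closing quote surfaces as UnterminatedString rather
    // than being silently truncated.
    #[test]
    fn test_unterminated_string() {
        let mut lexer = Lexer::new("SELECT * FROM users WHERE name = 'John");
        assert_eq!(lexer.tokenize(), Err(LexerError::UnterminatedString));
    }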

    #[test]
    fn test_order_by() {
        let mut lexer = Lexer::new("SELECT * FROM users ORDER BY name ASC");
        let tokens = lexer.tokenize().unwrap();

        assert!(tokens.contains(&Token::OrderBy));
        assert!(tokens.contains(&Token::Asc));
    }
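
    // ORDER only combines with a following BY; otherwise the lexer rolls
    // back its lookahead and emits ORDER as a plain identifier.
    #[test]
    fn test_order_without_by() {
        let mut lexer = Lexer::new("ORDER items");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![
                Token::Identifier("ORDER".to_string()),
                Token::Identifier("items".to_string()),
                Token::Eof,
            ]
        );
    }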

    #[test]
    fn test_operators() {
        let mut lexer = Lexer::new("= != < <= > >=");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![
                Token::Eq,
                Token::Ne,
                Token::Lt,
                Token::Le,
                Token::Gt,
                Token::Ge,
                Token::Eof,
            ]
        );
    }
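
    // TRUE, FALSE, and NULL are matched as literal tokens, not identifiers.
    #[test]
    fn test_boolean_and_null_literals() {
        let mut lexer = Lexer::new("TRUE FALSE NULL");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![
                Token::Boolean(true),
                Token::Boolean(false),
                Token::Null,
                Token::Eof,
            ]
        );
    }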

    #[test]
    fn test_numbers() {
        let mut lexer = Lexer::new("42 3.5");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![Token::Integer(42), Token::Float(3.5), Token::Eof]
        );
    }
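
    // Aggregates compose with punctuation as separate tokens: COUNT(*)
    // lexes as COUNT, (, *, ).
    #[test]
    fn test_aggregate_call() {
        let mut lexer = Lexer::new("SELECT COUNT(*) FROM users");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(tokens[1], Token::Count);
        assert_eq!(tokens[2], Token::LeftParen);
        assert_eq!(tokens[3], Token::Asterisk);
        assert_eq!(tokens[4], Token::RightParen);
    }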
}