1use std::fmt;
5
/// Lexical tokens produced by the lexer for the supported SQL subset.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords.
    Select,
    From,
    Where,
    Group,
    By,
    Having,
    /// The fused two-word keyword `ORDER BY`.
    OrderBy,
    Limit,
    Offset,
    Join,
    Inner,
    Left,
    Right,
    Full,
    On,
    As,
    And,
    Or,
    Not,
    Like,
    In,
    Between,

    // Aggregate function names.
    Count,
    Sum,
    Avg,
    Min,
    Max,

    // Comparison operators: = != < <= > >=.
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,

    // Literals.
    Integer(i64),
    Float(f64),
    String(String),
    Boolean(bool),
    Null,

    /// Table or column name, possibly qualified (e.g. `t.col`).
    Identifier(String),

    // Punctuation and sort direction.
    Asterisk,
    Comma,
    LeftParen,
    RightParen,
    Asc,
    Desc,

    /// End of input marker.
    Eof,
}
71
72impl fmt::Display for Token {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 match self {
75 Token::Select => write!(f, "SELECT"),
76 Token::From => write!(f, "FROM"),
77 Token::Where => write!(f, "WHERE"),
78 Token::Group => write!(f, "GROUP"),
79 Token::By => write!(f, "BY"),
80 Token::Having => write!(f, "HAVING"),
81 Token::OrderBy => write!(f, "ORDER BY"),
82 Token::Limit => write!(f, "LIMIT"),
83 Token::Offset => write!(f, "OFFSET"),
84 Token::Join => write!(f, "JOIN"),
85 Token::Inner => write!(f, "INNER"),
86 Token::Left => write!(f, "LEFT"),
87 Token::Right => write!(f, "RIGHT"),
88 Token::Full => write!(f, "FULL"),
89 Token::On => write!(f, "ON"),
90 Token::As => write!(f, "AS"),
91 Token::And => write!(f, "AND"),
92 Token::Or => write!(f, "OR"),
93 Token::Not => write!(f, "NOT"),
94 Token::Like => write!(f, "LIKE"),
95 Token::In => write!(f, "IN"),
96 Token::Between => write!(f, "BETWEEN"),
97 Token::Count => write!(f, "COUNT"),
98 Token::Sum => write!(f, "SUM"),
99 Token::Avg => write!(f, "AVG"),
100 Token::Min => write!(f, "MIN"),
101 Token::Max => write!(f, "MAX"),
102 Token::Eq => write!(f, "="),
103 Token::Ne => write!(f, "!="),
104 Token::Lt => write!(f, "<"),
105 Token::Le => write!(f, "<="),
106 Token::Gt => write!(f, ">"),
107 Token::Ge => write!(f, ">="),
108 Token::Integer(i) => write!(f, "{}", i),
109 Token::Float(fl) => write!(f, "{}", fl),
110 Token::String(s) => write!(f, "'{}'", s),
111 Token::Boolean(b) => write!(f, "{}", b),
112 Token::Null => write!(f, "NULL"),
113 Token::Identifier(id) => write!(f, "{}", id),
114 Token::Asterisk => write!(f, "*"),
115 Token::Comma => write!(f, ","),
116 Token::LeftParen => write!(f, "("),
117 Token::RightParen => write!(f, ")"),
118 Token::Asc => write!(f, "ASC"),
119 Token::Desc => write!(f, "DESC"),
120 Token::Eof => write!(f, "EOF"),
121 }
122 }
123}
124
/// A hand-written tokenizer over a fixed SQL input string.
pub struct Lexer {
    // Input decoded into chars up front so scanning indexes per character,
    // not per byte.
    input: Vec<char>,
    // Index into `input` of the next character to examine.
    position: usize,
}
130
131impl Lexer {
132 pub fn new(input: &str) -> Self {
134 Self {
135 input: input.chars().collect(),
136 position: 0,
137 }
138 }
139
140 pub fn next_token(&mut self) -> Result<Token, LexerError> {
142 self.skip_whitespace();
143
144 if self.position >= self.input.len() {
145 return Ok(Token::Eof);
146 }
147
148 let ch = self.current_char();
149
150 match ch {
152 '*' => {
153 self.advance();
154 return Ok(Token::Asterisk);
155 }
156 ',' => {
157 self.advance();
158 return Ok(Token::Comma);
159 }
160 '(' => {
161 self.advance();
162 return Ok(Token::LeftParen);
163 }
164 ')' => {
165 self.advance();
166 return Ok(Token::RightParen);
167 }
168 '=' => {
169 self.advance();
170 return Ok(Token::Eq);
171 }
172 '<' => {
173 self.advance();
174 if self.position < self.input.len() && self.current_char() == '=' {
175 self.advance();
176 return Ok(Token::Le);
177 }
178 return Ok(Token::Lt);
179 }
180 '>' => {
181 self.advance();
182 if self.position < self.input.len() && self.current_char() == '=' {
183 self.advance();
184 return Ok(Token::Ge);
185 }
186 return Ok(Token::Gt);
187 }
188 '!' => {
189 self.advance();
190 if self.position < self.input.len() && self.current_char() == '=' {
191 self.advance();
192 return Ok(Token::Ne);
193 }
194 return Err(LexerError::UnexpectedCharacter(ch));
195 }
196 '\'' => return self.read_string(),
197 _ => {}
198 }
199
200 if ch.is_ascii_digit() {
202 return self.read_number();
203 }
204
205 if ch.is_alphabetic() || ch == '_' {
207 return self.read_identifier_or_keyword();
208 }
209
210 Err(LexerError::UnexpectedCharacter(ch))
211 }
212
213 pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
215 let mut tokens = Vec::new();
216 loop {
217 let token = self.next_token()?;
218 if token == Token::Eof {
219 tokens.push(token);
220 break;
221 }
222 tokens.push(token);
223 }
224 Ok(tokens)
225 }
226
227 fn current_char(&self) -> char {
228 self.input[self.position]
229 }
230
231 fn peek_char(&self) -> Option<char> {
232 if self.position + 1 < self.input.len() {
233 Some(self.input[self.position + 1])
234 } else {
235 None
236 }
237 }
238
239 fn advance(&mut self) {
240 self.position += 1;
241 }
242
243 fn skip_whitespace(&mut self) {
244 while self.position < self.input.len() && self.current_char().is_whitespace() {
245 self.advance();
246 }
247 }
248
249 fn read_number(&mut self) -> Result<Token, LexerError> {
250 let start = self.position;
251 let mut has_dot = false;
252
253 while self.position < self.input.len() {
254 let ch = self.current_char();
255 if ch.is_ascii_digit() {
256 self.advance();
257 } else if ch == '.' && !has_dot && self.peek_char().is_some_and(|c| c.is_ascii_digit())
258 {
259 has_dot = true;
260 self.advance();
261 } else {
262 break;
263 }
264 }
265
266 let num_str: String = self.input[start..self.position].iter().collect();
267
268 if has_dot {
269 num_str
270 .parse::<f64>()
271 .map(Token::Float)
272 .map_err(|_| LexerError::InvalidNumber(num_str))
273 } else {
274 num_str
275 .parse::<i64>()
276 .map(Token::Integer)
277 .map_err(|_| LexerError::InvalidNumber(num_str))
278 }
279 }
280
281 fn read_string(&mut self) -> Result<Token, LexerError> {
282 self.advance(); let start = self.position;
284
285 while self.position < self.input.len() && self.current_char() != '\'' {
286 self.advance();
287 }
288
289 if self.position >= self.input.len() {
290 return Err(LexerError::UnterminatedString);
291 }
292
293 let string: String = self.input[start..self.position].iter().collect();
294 self.advance(); Ok(Token::String(string))
297 }
298
299 fn read_identifier_or_keyword(&mut self) -> Result<Token, LexerError> {
300 let start = self.position;
301
302 while self.position < self.input.len() {
303 let ch = self.current_char();
304 if ch.is_alphanumeric() || ch == '_' || ch == '.' {
305 self.advance();
306 } else {
307 break;
308 }
309 }
310
311 let text: String = self.input[start..self.position].iter().collect();
312 let uppercase = text.to_uppercase();
313
314 if uppercase == "ORDER" {
316 self.skip_whitespace();
317 if self.position < self.input.len() {
318 let next_start = self.position;
319 let mut next_text = String::new();
320 while self.position < self.input.len() {
321 let ch = self.current_char();
322 if ch.is_alphabetic() {
323 next_text.push(ch);
324 self.advance();
325 } else {
326 break;
327 }
328 }
329 if next_text.to_uppercase() == "BY" {
330 return Ok(Token::OrderBy);
331 }
332 self.position = next_start;
334 }
335 }
336
337 let token = match uppercase.as_str() {
339 "SELECT" => Token::Select,
340 "FROM" => Token::From,
341 "WHERE" => Token::Where,
342 "GROUP" => Token::Group,
343 "BY" => Token::By,
344 "HAVING" => Token::Having,
345 "LIMIT" => Token::Limit,
346 "OFFSET" => Token::Offset,
347 "JOIN" => Token::Join,
348 "INNER" => Token::Inner,
349 "LEFT" => Token::Left,
350 "RIGHT" => Token::Right,
351 "FULL" => Token::Full,
352 "ON" => Token::On,
353 "AS" => Token::As,
354 "AND" => Token::And,
355 "OR" => Token::Or,
356 "NOT" => Token::Not,
357 "LIKE" => Token::Like,
358 "IN" => Token::In,
359 "BETWEEN" => Token::Between,
360 "COUNT" => Token::Count,
361 "SUM" => Token::Sum,
362 "AVG" => Token::Avg,
363 "MIN" => Token::Min,
364 "MAX" => Token::Max,
365 "ASC" => Token::Asc,
366 "DESC" => Token::Desc,
367 "TRUE" => Token::Boolean(true),
368 "FALSE" => Token::Boolean(false),
369 "NULL" => Token::Null,
370 _ => Token::Identifier(text),
371 };
372
373 Ok(token)
374 }
375}
376
/// Errors the lexer can report while scanning.
#[derive(Debug, Clone, PartialEq)]
pub enum LexerError {
    /// A character that cannot start any token (or a lone `!`).
    UnexpectedCharacter(char),
    /// A digit run that failed to parse as `i64`/`f64` (e.g. out of range).
    InvalidNumber(String),
    /// Input ended inside a single-quoted string literal.
    UnterminatedString,
}
384
385impl fmt::Display for LexerError {
386 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387 match self {
388 LexerError::UnexpectedCharacter(ch) => write!(f, "Unexpected character: '{}'", ch),
389 LexerError::InvalidNumber(s) => write!(f, "Invalid number: '{}'", s),
390 LexerError::UnterminatedString => write!(f, "Unterminated string literal"),
391 }
392 }
393}
394
// Marker impl: lets `LexerError` be boxed as `dyn std::error::Error` and
// converted via `?`; `Display`/`Debug` above supply the messages.
impl std::error::Error for LexerError {}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes a statement to completion, panicking on any lexer error.
    fn lex(sql: &str) -> Vec<Token> {
        Lexer::new(sql).tokenize().expect("lexing should succeed")
    }

    /// Shorthand for an identifier token.
    fn ident(name: &str) -> Token {
        Token::Identifier(name.to_string())
    }

    #[test]
    fn test_simple_select() {
        assert_eq!(
            lex("SELECT * FROM users"),
            vec![
                Token::Select,
                Token::Asterisk,
                Token::From,
                ident("users"),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn test_select_with_where() {
        let tokens = lex("SELECT name FROM users WHERE age > 18");
        let expected = [
            Token::Select,
            ident("name"),
            Token::From,
            ident("users"),
            Token::Where,
            ident("age"),
            Token::Gt,
            Token::Integer(18),
        ];
        assert_eq!(&tokens[..expected.len()], &expected[..]);
    }

    #[test]
    fn test_string_literals() {
        let tokens = lex("SELECT * FROM users WHERE name = 'John'");
        assert!(tokens.contains(&Token::String("John".to_string())));
    }

    #[test]
    fn test_order_by() {
        let tokens = lex("SELECT * FROM users ORDER BY name ASC");
        assert!(tokens.contains(&Token::OrderBy));
        assert!(tokens.contains(&Token::Asc));
    }

    #[test]
    fn test_operators() {
        assert_eq!(
            lex("= != < <= > >="),
            vec![
                Token::Eq,
                Token::Ne,
                Token::Lt,
                Token::Le,
                Token::Gt,
                Token::Ge,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn test_numbers() {
        assert_eq!(
            lex("42 3.5"),
            vec![Token::Integer(42), Token::Float(3.5), Token::Eof]
        );
    }
}