//! A hand-written lexer for a small SQL-like query language.

use std::fmt;

/// Tokens produced by [`Lexer`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords
    Select,
    From,
    Where,
    OrderBy,
    Limit,
    Offset,
    Join,
    Inner,
    Left,
    Right,
    Full,
    On,
    As,
    And,
    Or,
    Not,
    Like,
    In,
    Between,

    // Aggregate functions
    Count,
    Sum,
    Avg,
    Min,
    Max,

    // Comparison operators
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,

    // Literals
    Integer(i64),
    Float(f64),
    String(String),
    Boolean(bool),
    Null,

    // Identifiers (table and column names, possibly dot-qualified)
    Identifier(String),

    // Punctuation and sort direction
    Asterisk,
    Comma,
    LeftParen,
    RightParen,
    Asc,
    Desc,

    // End of input
    Eof,
}

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Token::Select => write!(f, "SELECT"),
            Token::From => write!(f, "FROM"),
            Token::Where => write!(f, "WHERE"),
            Token::OrderBy => write!(f, "ORDER BY"),
            Token::Limit => write!(f, "LIMIT"),
            Token::Offset => write!(f, "OFFSET"),
            Token::Join => write!(f, "JOIN"),
            Token::Inner => write!(f, "INNER"),
            Token::Left => write!(f, "LEFT"),
            Token::Right => write!(f, "RIGHT"),
            Token::Full => write!(f, "FULL"),
            Token::On => write!(f, "ON"),
            Token::As => write!(f, "AS"),
            Token::And => write!(f, "AND"),
            Token::Or => write!(f, "OR"),
            Token::Not => write!(f, "NOT"),
            Token::Like => write!(f, "LIKE"),
            Token::In => write!(f, "IN"),
            Token::Between => write!(f, "BETWEEN"),
            Token::Count => write!(f, "COUNT"),
            Token::Sum => write!(f, "SUM"),
            Token::Avg => write!(f, "AVG"),
            Token::Min => write!(f, "MIN"),
            Token::Max => write!(f, "MAX"),
            Token::Eq => write!(f, "="),
            Token::Ne => write!(f, "!="),
            Token::Lt => write!(f, "<"),
            Token::Le => write!(f, "<="),
            Token::Gt => write!(f, ">"),
            Token::Ge => write!(f, ">="),
            Token::Integer(i) => write!(f, "{}", i),
            Token::Float(fl) => write!(f, "{}", fl),
            Token::String(s) => write!(f, "'{}'", s),
            Token::Boolean(b) => write!(f, "{}", b),
            Token::Null => write!(f, "NULL"),
            Token::Identifier(id) => write!(f, "{}", id),
            Token::Asterisk => write!(f, "*"),
            Token::Comma => write!(f, ","),
            Token::LeftParen => write!(f, "("),
            Token::RightParen => write!(f, ")"),
            Token::Asc => write!(f, "ASC"),
            Token::Desc => write!(f, "DESC"),
            Token::Eof => write!(f, "EOF"),
        }
    }
}

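/// A streaming lexer over a query string.
///
/// The input is buffered as a `Vec<char>` and consumed one character at a
/// time; call [`Lexer::next_token`] repeatedly, or [`Lexer::tokenize`] to
/// collect every token (terminated by [`Token::Eof`]) in one pass.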
pub struct Lexer {
    input: Vec<char>,
    position: usize,
}

impl Lexer {
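    /// Creates a lexer positioned at the start of `input`.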
    pub fn new(input: &str) -> Self {
        Self {
            input: input.chars().collect(),
            position: 0,
        }
    }

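    /// Returns the next token, or [`Token::Eof`] once the input is exhausted.
    ///
    /// Fails with [`LexerError::UnexpectedCharacter`] on input that starts no
    /// known token.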
    pub fn next_token(&mut self) -> Result<Token, LexerError> {
        self.skip_whitespace();

        if self.position >= self.input.len() {
            return Ok(Token::Eof);
        }

        let ch = self.current_char();

        // Single-character punctuation and one- or two-character operators.
        match ch {
            '*' => {
                self.advance();
                return Ok(Token::Asterisk);
            }
            ',' => {
                self.advance();
                return Ok(Token::Comma);
            }
            '(' => {
                self.advance();
                return Ok(Token::LeftParen);
            }
            ')' => {
                self.advance();
                return Ok(Token::RightParen);
            }
            '=' => {
                self.advance();
                return Ok(Token::Eq);
            }
            '<' => {
                self.advance();
                if self.position < self.input.len() && self.current_char() == '=' {
                    self.advance();
                    return Ok(Token::Le);
                }
                return Ok(Token::Lt);
            }
            '>' => {
                self.advance();
                if self.position < self.input.len() && self.current_char() == '=' {
                    self.advance();
                    return Ok(Token::Ge);
                }
                return Ok(Token::Gt);
            }
            '!' => {
                self.advance();
                if self.position < self.input.len() && self.current_char() == '=' {
                    self.advance();
                    return Ok(Token::Ne);
                }
                // A bare '!' is not a valid token on its own.
                return Err(LexerError::UnexpectedCharacter(ch));
            }
            '\'' => return self.read_string(),
            _ => {}
        }

        // Multi-character tokens: numbers, then identifiers and keywords.
        if ch.is_ascii_digit() {
            return self.read_number();
        }

        if ch.is_alphabetic() || ch == '_' {
            return self.read_identifier_or_keyword();
        }

        Err(LexerError::UnexpectedCharacter(ch))
    }

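    /// Consumes the remaining input and returns all tokens, including the
    /// trailing [`Token::Eof`]. Stops at the first lexing error.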
    pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token()?;
            if token == Token::Eof {
                tokens.push(token);
                break;
            }
            tokens.push(token);
        }
        Ok(tokens)
    }

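    /// Returns the character at the current position. Callers must ensure the
    /// position is in bounds; indexing panics otherwise.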
    fn current_char(&self) -> char {
        self.input[self.position]
    }

    fn peek_char(&self) -> Option<char> {
        if self.position + 1 < self.input.len() {
            Some(self.input[self.position + 1])
        } else {
            None
        }
    }

    fn advance(&mut self) {
        self.position += 1;
    }

    fn skip_whitespace(&mut self) {
        while self.position < self.input.len() && self.current_char().is_whitespace() {
            self.advance();
        }
    }

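    /// Reads an integer or float literal. A decimal point is only consumed
    /// when the character after it is another digit.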
    fn read_number(&mut self) -> Result<Token, LexerError> {
        let start = self.position;
        let mut has_dot = false;

        while self.position < self.input.len() {
            let ch = self.current_char();
            if ch.is_ascii_digit() {
                self.advance();
            } else if ch == '.' && !has_dot && self.peek_char().is_some_and(|c| c.is_ascii_digit())
            {
                has_dot = true;
                self.advance();
            } else {
                break;
            }
        }

        let num_str: String = self.input[start..self.position].iter().collect();

        if has_dot {
            num_str
                .parse::<f64>()
                .map(Token::Float)
                .map_err(|_| LexerError::InvalidNumber(num_str))
        } else {
            num_str
                .parse::<i64>()
                .map(Token::Integer)
                .map_err(|_| LexerError::InvalidNumber(num_str))
        }
    }

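    /// Reads a single-quoted string literal, with the cursor on the opening
    /// quote. Escape sequences are not supported, so the literal ends at the
    /// next `'`.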
    fn read_string(&mut self) -> Result<Token, LexerError> {
        // Skip the opening quote.
        self.advance();
        let start = self.position;

        while self.position < self.input.len() && self.current_char() != '\'' {
            self.advance();
        }

        if self.position >= self.input.len() {
            return Err(LexerError::UnterminatedString);
        }

        let string: String = self.input[start..self.position].iter().collect();
        // Skip the closing quote.
        self.advance();
        Ok(Token::String(string))
    }

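    /// Reads a word and classifies it as a keyword, literal (`TRUE`, `FALSE`,
    /// `NULL`), or identifier. `ORDER` triggers a lookahead for `BY` so that
    /// "ORDER BY" becomes a single [`Token::OrderBy`].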
    fn read_identifier_or_keyword(&mut self) -> Result<Token, LexerError> {
        let start = self.position;

        while self.position < self.input.len() {
            let ch = self.current_char();
            if ch.is_alphanumeric() || ch == '_' || ch == '.' {
                self.advance();
            } else {
                break;
            }
        }

        let text: String = self.input[start..self.position].iter().collect();
        let uppercase = text.to_uppercase();

        // "ORDER BY" spans two words; look ahead for "BY" before classifying.
        if uppercase == "ORDER" {
            self.skip_whitespace();
            if self.position < self.input.len() {
                let next_start = self.position;
                let mut next_text = String::new();
                while self.position < self.input.len() {
                    let ch = self.current_char();
                    if ch.is_alphabetic() {
                        next_text.push(ch);
                        self.advance();
                    } else {
                        break;
                    }
                }
                if next_text.to_uppercase() == "BY" {
                    return Ok(Token::OrderBy);
                }
                // Not followed by BY: rewind so the next word is lexed normally.
                self.position = next_start;
            }
        }

        // Keywords are matched case-insensitively; anything else is an identifier.
        let token = match uppercase.as_str() {
            "SELECT" => Token::Select,
            "FROM" => Token::From,
            "WHERE" => Token::Where,
            "LIMIT" => Token::Limit,
            "OFFSET" => Token::Offset,
            "JOIN" => Token::Join,
            "INNER" => Token::Inner,
            "LEFT" => Token::Left,
            "RIGHT" => Token::Right,
            "FULL" => Token::Full,
            "ON" => Token::On,
            "AS" => Token::As,
            "AND" => Token::And,
            "OR" => Token::Or,
            "NOT" => Token::Not,
            "LIKE" => Token::Like,
            "IN" => Token::In,
            "BETWEEN" => Token::Between,
            "COUNT" => Token::Count,
            "SUM" => Token::Sum,
            "AVG" => Token::Avg,
            "MIN" => Token::Min,
            "MAX" => Token::Max,
            "ASC" => Token::Asc,
            "DESC" => Token::Desc,
            "TRUE" => Token::Boolean(true),
            "FALSE" => Token::Boolean(false),
            "NULL" => Token::Null,
            _ => Token::Identifier(text),
        };

        Ok(token)
    }
}

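/// Errors produced while tokenizing.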
#[derive(Debug, Clone, PartialEq)]
pub enum LexerError {
    UnexpectedCharacter(char),
    InvalidNumber(String),
    UnterminatedString,
}

impl fmt::Display for LexerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LexerError::UnexpectedCharacter(ch) => write!(f, "Unexpected character: '{}'", ch),
            LexerError::InvalidNumber(s) => write!(f, "Invalid number: '{}'", s),
            LexerError::UnterminatedString => write!(f, "Unterminated string literal"),
        }
    }
}

impl std::error::Error for LexerError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_select() {
        let mut lexer = Lexer::new("SELECT * FROM users");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![
                Token::Select,
                Token::Asterisk,
                Token::From,
                Token::Identifier("users".to_string()),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn test_select_with_where() {
        let mut lexer = Lexer::new("SELECT name FROM users WHERE age > 18");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(tokens[0], Token::Select);
        assert_eq!(tokens[1], Token::Identifier("name".to_string()));
        assert_eq!(tokens[2], Token::From);
        assert_eq!(tokens[3], Token::Identifier("users".to_string()));
        assert_eq!(tokens[4], Token::Where);
        assert_eq!(tokens[5], Token::Identifier("age".to_string()));
        assert_eq!(tokens[6], Token::Gt);
        assert_eq!(tokens[7], Token::Integer(18));
    }

    #[test]
    fn test_string_literals() {
        let mut lexer = Lexer::new("SELECT * FROM users WHERE name = 'John'");
        let tokens = lexer.tokenize().unwrap();

        assert!(tokens.contains(&Token::String("John".to_string())));
    }

    #[test]
    fn test_order_by() {
        let mut lexer = Lexer::new("SELECT * FROM users ORDER BY name ASC");
        let tokens = lexer.tokenize().unwrap();

        assert!(tokens.contains(&Token::OrderBy));
        assert!(tokens.contains(&Token::Asc));
    }

    #[test]
    fn test_operators() {
        let mut lexer = Lexer::new("= != < <= > >=");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![
                Token::Eq,
                Token::Ne,
                Token::Lt,
                Token::Le,
                Token::Gt,
                Token::Ge,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn test_numbers() {
        let mut lexer = Lexer::new("42 3.5");
        let tokens = lexer.tokenize().unwrap();

        assert_eq!(
            tokens,
            vec![Token::Integer(42), Token::Float(3.5), Token::Eof]
        );
    }
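
    // Additional sketch, not part of the original suite: exercises the error
    // variants defined above, assuming tokenize stops at the first failure.
    #[test]
    fn test_error_cases() {
        assert_eq!(
            Lexer::new("'unterminated").tokenize(),
            Err(LexerError::UnterminatedString)
        );
        assert_eq!(
            Lexer::new("SELECT @").tokenize(),
            Err(LexerError::UnexpectedCharacter('@'))
        );
    }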
}