1use super::token::{Span, Token, TokenKind};
24use std::iter::Peekable;
25use std::str::Chars;
26
/// An error produced while lexing, carrying the source location at which
/// the offending text begins.
#[derive(Debug, Clone, PartialEq)]
pub struct LexError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Location (byte range plus 1-based line/column) of the bad input.
    pub span: Span,
}
33
34impl LexError {
35 pub fn new(message: impl Into<String>, span: Span) -> Self {
36 Self {
37 message: message.into(),
38 span,
39 }
40 }
41}
42
43impl std::fmt::Display for LexError {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(
46 f,
47 "Lexer error at line {}, column {}: {}",
48 self.span.line, self.span.column, self.message
49 )
50 }
51}
52
// Marker impl so LexError can travel as `Box<dyn Error>` / through `?`.
impl std::error::Error for LexError {}
54
/// Hand-written lexer over a SQL-like input string. Consume it with
/// `tokenize` to obtain either the full token stream or all collected errors.
pub struct Lexer<'a> {
    /// Full source text; token literals are sliced out of this.
    input: &'a str,
    /// Character cursor over `input`.
    chars: Peekable<Chars<'a>>,
    /// Current byte offset into `input`.
    pos: usize,
    /// Current 1-based line number.
    line: usize,
    /// Current 1-based column number.
    column: usize,
    /// Tokens produced so far.
    tokens: Vec<Token>,
    /// Errors collected so far; lexing continues after an error.
    errors: Vec<LexError>,
    /// Auto-increment counter used to number anonymous `?` placeholders.
    placeholder_counter: u32,
}
67
68impl<'a> Lexer<'a> {
69 pub fn new(input: &'a str) -> Self {
71 Self {
72 input,
73 chars: input.chars().peekable(),
74 pos: 0,
75 line: 1,
76 column: 1,
77 tokens: Vec::new(),
78 errors: Vec::new(),
79 placeholder_counter: 0,
80 }
81 }
82
83 pub fn tokenize(mut self) -> Result<Vec<Token>, Vec<LexError>> {
85 while !self.is_at_end() {
86 self.scan_token();
87 }
88
89 self.tokens.push(Token::new(
91 TokenKind::Eof,
92 Span::new(self.pos, self.pos, self.line, self.column),
93 "",
94 ));
95
96 if self.errors.is_empty() {
97 Ok(self.tokens)
98 } else {
99 Err(self.errors)
100 }
101 }
102
103 fn is_at_end(&mut self) -> bool {
104 self.chars.peek().is_none()
105 }
106
107 fn advance(&mut self) -> Option<char> {
108 let c = self.chars.next()?;
109 self.pos += c.len_utf8();
110 if c == '\n' {
111 self.line += 1;
112 self.column = 1;
113 } else {
114 self.column += 1;
115 }
116 Some(c)
117 }
118
119 fn peek(&mut self) -> Option<char> {
120 self.chars.peek().copied()
121 }
122
123 fn peek_next(&self) -> Option<char> {
124 let mut chars = self.chars.clone();
125 chars.next();
126 chars.next()
127 }
128
129 fn make_span(&self, start: usize, start_line: usize, start_col: usize) -> Span {
130 Span::new(start, self.pos, start_line, start_col)
131 }
132
133 fn scan_token(&mut self) {
134 let start = self.pos;
135 let start_line = self.line;
136 let start_col = self.column;
137
138 let c = match self.advance() {
139 Some(c) => c,
140 None => return,
141 };
142
143 match c {
144 ' ' | '\t' | '\r' | '\n' => {
146 }
148
149 '(' => self.add_token(TokenKind::LParen, start, start_line, start_col),
151 ')' => self.add_token(TokenKind::RParen, start, start_line, start_col),
152 '[' => self.add_token(TokenKind::LBracket, start, start_line, start_col),
153 ']' => self.add_token(TokenKind::RBracket, start, start_line, start_col),
154 ',' => self.add_token(TokenKind::Comma, start, start_line, start_col),
155 ';' => self.add_token(TokenKind::Semicolon, start, start_line, start_col),
156 '+' => self.add_token(TokenKind::Plus, start, start_line, start_col),
157 '*' => self.add_token(TokenKind::Star, start, start_line, start_col),
158 '/' => {
159 if self.peek() == Some('/') || self.peek() == Some('*') {
160 self.scan_comment(start, start_line, start_col);
161 } else {
162 self.add_token(TokenKind::Slash, start, start_line, start_col);
163 }
164 }
165 '%' => self.add_token(TokenKind::Percent, start, start_line, start_col),
166 '&' => self.add_token(TokenKind::BitAnd, start, start_line, start_col),
167 '~' => self.add_token(TokenKind::BitNot, start, start_line, start_col),
168 '?' => {
169 self.placeholder_counter += 1;
171 let span = self.make_span(start, start_line, start_col);
172 self.tokens.push(Token::new(
173 TokenKind::Placeholder(self.placeholder_counter),
174 span,
175 "?",
176 ));
177 }
178 '@' => self.add_token(TokenKind::At, start, start_line, start_col),
179
180 '-' => {
182 if self.peek() == Some('-') {
183 self.scan_line_comment(start, start_line, start_col);
185 } else if self.peek() == Some('>') {
186 self.advance();
187 if self.peek() == Some('>') {
188 self.advance();
189 self.add_token(TokenKind::DoubleArrow, start, start_line, start_col);
190 } else {
191 self.add_token(TokenKind::Arrow, start, start_line, start_col);
192 }
193 } else {
194 self.add_token(TokenKind::Minus, start, start_line, start_col);
195 }
196 }
197
198 '=' => self.add_token(TokenKind::Eq, start, start_line, start_col),
199
200 '!' => {
201 if self.peek() == Some('=') {
202 self.advance();
203 self.add_token(TokenKind::Ne, start, start_line, start_col);
204 } else {
205 self.add_error("Unexpected character '!'", start, start_line, start_col);
206 }
207 }
208
209 '<' => {
210 if self.peek() == Some('=') {
211 self.advance();
212 self.add_token(TokenKind::Le, start, start_line, start_col);
213 } else if self.peek() == Some('>') {
214 self.advance();
215 self.add_token(TokenKind::Ne, start, start_line, start_col);
216 } else if self.peek() == Some('<') {
217 self.advance();
218 self.add_token(TokenKind::LeftShift, start, start_line, start_col);
219 } else {
220 self.add_token(TokenKind::Lt, start, start_line, start_col);
221 }
222 }
223
224 '>' => {
225 if self.peek() == Some('=') {
226 self.advance();
227 self.add_token(TokenKind::Ge, start, start_line, start_col);
228 } else if self.peek() == Some('>') {
229 self.advance();
230 self.add_token(TokenKind::RightShift, start, start_line, start_col);
231 } else {
232 self.add_token(TokenKind::Gt, start, start_line, start_col);
233 }
234 }
235
236 '|' => {
237 if self.peek() == Some('|') {
238 self.advance();
239 self.add_token(TokenKind::Concat, start, start_line, start_col);
240 } else {
241 self.add_token(TokenKind::BitOr, start, start_line, start_col);
242 }
243 }
244
245 ':' => {
246 if self.peek() == Some(':') {
247 self.advance();
248 self.add_token(TokenKind::DoubleColon, start, start_line, start_col);
249 } else {
250 self.add_token(TokenKind::Colon, start, start_line, start_col);
251 }
252 }
253
254 '.' => {
255 if self.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
256 self.scan_number(start, start_line, start_col, true);
257 } else {
258 self.add_token(TokenKind::Dot, start, start_line, start_col);
259 }
260 }
261
262 '\'' => self.scan_string(start, start_line, start_col, '\''),
264 '"' => self.scan_quoted_identifier(start, start_line, start_col, '"'),
265 '`' => self.scan_quoted_identifier(start, start_line, start_col, '`'),
266
267 'X' | 'x' if self.peek() == Some('\'') => {
269 self.advance();
270 self.scan_blob(start, start_line, start_col);
271 }
272
273 '0'..='9' => self.scan_number(start, start_line, start_col, false),
275
276 'a'..='z' | 'A'..='Z' | '_' => self.scan_identifier(start, start_line, start_col),
278
279 '$' => self.scan_placeholder(start, start_line, start_col),
281
282 _ => {
283 self.add_error(
284 format!("Unexpected character '{}'", c),
285 start,
286 start_line,
287 start_col,
288 );
289 }
290 }
291 }
292
293 fn scan_string(&mut self, start: usize, start_line: usize, start_col: usize, quote: char) {
294 let mut value = String::new();
295
296 while let Some(c) = self.peek() {
297 if c == quote {
298 self.advance();
299 if self.peek() == Some(quote) {
301 self.advance();
302 value.push(quote);
303 } else {
304 let span = self.make_span(start, start_line, start_col);
306 self.tokens
307 .push(Token::new(TokenKind::String(value), span, ""));
308 return;
309 }
310 } else if c == '\\' {
311 self.advance();
312 if let Some(escaped) = self.advance() {
314 match escaped {
315 'n' => value.push('\n'),
316 'r' => value.push('\r'),
317 't' => value.push('\t'),
318 '\\' => value.push('\\'),
319 '\'' => value.push('\''),
320 '"' => value.push('"'),
321 '0' => value.push('\0'),
322 _ => {
323 value.push('\\');
324 value.push(escaped);
325 }
326 }
327 }
328 } else {
329 self.advance();
330 value.push(c);
331 }
332 }
333
334 self.add_error("Unterminated string literal", start, start_line, start_col);
335 }
336
337 fn scan_quoted_identifier(
338 &mut self,
339 start: usize,
340 start_line: usize,
341 start_col: usize,
342 quote: char,
343 ) {
344 let mut value = String::new();
345
346 while let Some(c) = self.peek() {
347 if c == quote {
348 self.advance();
349 if self.peek() == Some(quote) {
351 self.advance();
352 value.push(quote);
353 } else {
354 let span = self.make_span(start, start_line, start_col);
355 self.tokens
356 .push(Token::new(TokenKind::QuotedIdentifier(value), span, ""));
357 return;
358 }
359 } else {
360 self.advance();
361 value.push(c);
362 }
363 }
364
365 self.add_error(
366 "Unterminated quoted identifier",
367 start,
368 start_line,
369 start_col,
370 );
371 }
372
373 fn scan_number(
374 &mut self,
375 start: usize,
376 start_line: usize,
377 start_col: usize,
378 started_with_dot: bool,
379 ) {
380 let num_start = start;
381 let mut has_dot = started_with_dot;
382 let mut has_exp = false;
383
384 while let Some(c) = self.peek() {
386 if c.is_ascii_digit() {
387 self.advance();
388 } else if c == '.' && !has_dot && !has_exp {
389 if self.peek_next() == Some('.') {
391 break;
392 }
393 has_dot = true;
394 self.advance();
395 } else if (c == 'e' || c == 'E') && !has_exp {
396 has_exp = true;
397 self.advance();
398 if self.peek() == Some('+') || self.peek() == Some('-') {
400 self.advance();
401 }
402 } else {
403 break;
404 }
405 }
406
407 let literal = &self.input[num_start..self.pos];
408 let span = self.make_span(start, start_line, start_col);
409
410 if has_dot || has_exp {
411 match literal.parse::<f64>() {
412 Ok(n) => self
413 .tokens
414 .push(Token::new(TokenKind::Float(n), span, literal)),
415 Err(_) => self.add_error("Invalid float literal", start, start_line, start_col),
416 }
417 } else {
418 match literal.parse::<i64>() {
419 Ok(n) => self
420 .tokens
421 .push(Token::new(TokenKind::Integer(n), span, literal)),
422 Err(_) => self.add_error("Invalid integer literal", start, start_line, start_col),
423 }
424 }
425 }
426
427 fn scan_identifier(&mut self, start: usize, start_line: usize, start_col: usize) {
428 while let Some(c) = self.peek() {
429 if c.is_ascii_alphanumeric() || c == '_' {
430 self.advance();
431 } else {
432 break;
433 }
434 }
435
436 let literal = &self.input[start..self.pos];
437 let span = self.make_span(start, start_line, start_col);
438
439 let kind = TokenKind::from_keyword(literal)
441 .unwrap_or_else(|| TokenKind::Identifier(literal.to_string()));
442
443 self.tokens.push(Token::new(kind, span, literal));
444 }
445
446 fn scan_placeholder(&mut self, start: usize, start_line: usize, start_col: usize) {
447 let mut num = String::new();
448
449 while let Some(c) = self.peek() {
450 if c.is_ascii_digit() {
451 self.advance();
452 num.push(c);
453 } else {
454 break;
455 }
456 }
457
458 let span = self.make_span(start, start_line, start_col);
459
460 if num.is_empty() {
461 self.add_error("Expected number after $", start, start_line, start_col);
462 } else if let Ok(n) = num.parse::<u32>() {
463 self.tokens.push(Token::new(
464 TokenKind::Placeholder(n),
465 span,
466 &self.input[start..self.pos],
467 ));
468 } else {
469 self.add_error("Invalid placeholder number", start, start_line, start_col);
470 }
471 }
472
473 fn scan_comment(&mut self, start: usize, start_line: usize, start_col: usize) {
474 self.advance(); if self.peek() == Some('*') || self.input[start..self.pos].ends_with('*') {
477 let mut depth = 1;
479
480 while depth > 0 && !self.is_at_end() {
481 let c = self.peek();
482 let next = self.peek_next();
483
484 if c == Some('*') && next == Some('/') {
485 self.advance();
486 self.advance();
487 depth -= 1;
488 } else if c == Some('/') && next == Some('*') {
489 self.advance();
490 self.advance();
491 depth += 1;
492 } else {
493 self.advance();
494 }
495 }
496
497 if depth > 0 {
498 self.add_error("Unterminated block comment", start, start_line, start_col);
499 }
500 } else {
501 while let Some(c) = self.peek() {
503 if c == '\n' {
504 break;
505 }
506 self.advance();
507 }
508 }
509 }
511
512 fn scan_line_comment(&mut self, _start: usize, _start_line: usize, _start_col: usize) {
513 self.advance(); while let Some(c) = self.peek() {
516 if c == '\n' {
517 break;
518 }
519 self.advance();
520 }
521 }
523
524 fn scan_blob(&mut self, start: usize, start_line: usize, start_col: usize) {
525 let mut hex = String::new();
526
527 while let Some(c) = self.peek() {
528 if c == '\'' {
529 self.advance();
530 break;
531 } else if c.is_ascii_hexdigit() {
532 self.advance();
533 hex.push(c);
534 } else if c.is_whitespace() {
535 self.advance(); } else {
537 self.add_error(
538 "Invalid hex digit in blob literal",
539 start,
540 start_line,
541 start_col,
542 );
543 return;
544 }
545 }
546
547 if !hex.len().is_multiple_of(2) {
548 self.add_error(
549 "Blob literal must have even number of hex digits",
550 start,
551 start_line,
552 start_col,
553 );
554 return;
555 }
556
557 let bytes: Result<Vec<u8>, _> = (0..hex.len())
558 .step_by(2)
559 .map(|i| u8::from_str_radix(&hex[i..i + 2], 16))
560 .collect();
561
562 match bytes {
563 Ok(data) => {
564 let span = self.make_span(start, start_line, start_col);
565 self.tokens
566 .push(Token::new(TokenKind::Blob(data), span, ""));
567 }
568 Err(_) => {
569 self.add_error("Invalid blob literal", start, start_line, start_col);
570 }
571 }
572 }
573
574 fn add_token(&mut self, kind: TokenKind, start: usize, start_line: usize, start_col: usize) {
575 let span = self.make_span(start, start_line, start_col);
576 let literal = &self.input[start..self.pos];
577 self.tokens.push(Token::new(kind, span, literal));
578 }
579
580 fn add_error(
581 &mut self,
582 message: impl Into<String>,
583 start: usize,
584 start_line: usize,
585 start_col: usize,
586 ) {
587 let span = self.make_span(start, start_line, start_col);
588 self.errors.push(LexError::new(message, span));
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    // A basic statement lexes to keyword/operator/identifier plus the Eof
    // sentinel (hence 5 tokens for 4 lexemes).
    #[test]
    fn test_simple_select() {
        let tokens = Lexer::new("SELECT * FROM users").tokenize().unwrap();
        assert_eq!(tokens.len(), 5); assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
        assert_eq!(tokens[2].kind, TokenKind::From);
        assert!(matches!(tokens[3].kind, TokenKind::Identifier(_)));
    }

    // A doubled quote inside a string collapses to a single literal quote.
    #[test]
    fn test_string_literal() {
        let tokens = Lexer::new("SELECT 'hello''world'").tokenize().unwrap();
        assert!(matches!(&tokens[1].kind, TokenKind::String(s) if s == "hello'world"));
    }

    // Integer, decimal, exponent, and leading-dot forms are all recognized.
    #[test]
    #[allow(clippy::approx_constant)]
    fn test_numbers() {
        let tokens = Lexer::new("42 3.14 1e10 .5").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Integer(42)));
        assert!(matches!(tokens[1].kind, TokenKind::Float(f) if (f - 3.14).abs() < 0.001));
        assert!(matches!(tokens[2].kind, TokenKind::Float(_)));
        assert!(matches!(tokens[3].kind, TokenKind::Float(f) if (f - 0.5).abs() < 0.001));
    }

    // Single- and multi-character operators, including SQL's `<>` not-equal.
    #[test]
    fn test_operators() {
        let tokens = Lexer::new("= != <> < <= > >= || ->").tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Eq);
        assert_eq!(tokens[1].kind, TokenKind::Ne);
        assert_eq!(tokens[2].kind, TokenKind::Ne);
        assert_eq!(tokens[3].kind, TokenKind::Lt);
        assert_eq!(tokens[4].kind, TokenKind::Le);
        assert_eq!(tokens[5].kind, TokenKind::Gt);
        assert_eq!(tokens[6].kind, TokenKind::Ge);
        assert_eq!(tokens[7].kind, TokenKind::Concat);
        assert_eq!(tokens[8].kind, TokenKind::Arrow);
    }

    // Reserved words map to their dedicated TokenKind variants.
    #[test]
    fn test_keywords() {
        let tokens = Lexer::new("SELECT INSERT UPDATE DELETE FROM WHERE")
            .tokenize()
            .unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Insert);
        assert_eq!(tokens[2].kind, TokenKind::Update);
        assert_eq!(tokens[3].kind, TokenKind::Delete);
        assert_eq!(tokens[4].kind, TokenKind::From);
        assert_eq!(tokens[5].kind, TokenKind::Where);
    }

    // `$n` placeholders keep their explicit numbers, including multi-digit.
    #[test]
    fn test_placeholder() {
        let tokens = Lexer::new("$1 $2 $10").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Placeholder(1)));
        assert!(matches!(tokens[1].kind, TokenKind::Placeholder(2)));
        assert!(matches!(tokens[2].kind, TokenKind::Placeholder(10)));
    }

    // `--` comments are skipped entirely and produce no token.
    #[test]
    fn test_line_comment() {
        let tokens = Lexer::new("SELECT -- comment\n* FROM users")
            .tokenize()
            .unwrap();
        assert_eq!(tokens.len(), 5); assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
    }

    // Hex blob literal decodes pairs of digits into raw bytes.
    #[test]
    fn test_blob_literal() {
        let tokens = Lexer::new("X'48454C4C4F'").tokenize().unwrap();
        assert!(matches!(&tokens[0].kind, TokenKind::Blob(b) if b == b"HELLO"));
    }
}
672}