1use super::token::{Span, Token, TokenKind};
21use std::iter::Peekable;
22use std::str::Chars;
23
/// An error produced while lexing, pairing a human-readable message with
/// the source span where the offending text was found.
#[derive(Debug, Clone, PartialEq)]
pub struct LexError {
    /// Description of what went wrong.
    pub message: String,
    /// Location of the error (byte range plus 1-based line/column).
    pub span: Span,
}
30
31impl LexError {
32 pub fn new(message: impl Into<String>, span: Span) -> Self {
33 Self {
34 message: message.into(),
35 span,
36 }
37 }
38}
39
40impl std::fmt::Display for LexError {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(
43 f,
44 "Lexer error at line {}, column {}: {}",
45 self.span.line, self.span.column, self.message
46 )
47 }
48}
49
// `LexError` has no underlying cause, so the default `Error` impl suffices.
impl std::error::Error for LexError {}
51
/// Hand-written lexer that turns SQL-like source text into a token stream.
pub struct Lexer<'a> {
    /// Full input text; token literals are sliced from it by byte offset.
    input: &'a str,
    /// Character cursor over `input` with one character of lookahead.
    chars: Peekable<Chars<'a>>,
    /// Current byte offset into `input`.
    pos: usize,
    /// Current 1-based line number.
    line: usize,
    /// Current 1-based column number.
    column: usize,
    /// Tokens produced so far.
    tokens: Vec<Token>,
    /// Errors collected so far; lexing continues past an error.
    errors: Vec<LexError>,
    /// Counter used to number anonymous `?` placeholders in order of appearance.
    placeholder_counter: u32,
}
64
65impl<'a> Lexer<'a> {
66 pub fn new(input: &'a str) -> Self {
68 Self {
69 input,
70 chars: input.chars().peekable(),
71 pos: 0,
72 line: 1,
73 column: 1,
74 tokens: Vec::new(),
75 errors: Vec::new(),
76 placeholder_counter: 0,
77 }
78 }
79
80 pub fn tokenize(mut self) -> Result<Vec<Token>, Vec<LexError>> {
82 while !self.is_at_end() {
83 self.scan_token();
84 }
85
86 self.tokens.push(Token::new(
88 TokenKind::Eof,
89 Span::new(self.pos, self.pos, self.line, self.column),
90 "",
91 ));
92
93 if self.errors.is_empty() {
94 Ok(self.tokens)
95 } else {
96 Err(self.errors)
97 }
98 }
99
100 fn is_at_end(&mut self) -> bool {
101 self.chars.peek().is_none()
102 }
103
104 fn advance(&mut self) -> Option<char> {
105 let c = self.chars.next()?;
106 self.pos += c.len_utf8();
107 if c == '\n' {
108 self.line += 1;
109 self.column = 1;
110 } else {
111 self.column += 1;
112 }
113 Some(c)
114 }
115
116 fn peek(&mut self) -> Option<char> {
117 self.chars.peek().copied()
118 }
119
120 fn peek_next(&self) -> Option<char> {
121 let mut chars = self.chars.clone();
122 chars.next();
123 chars.next()
124 }
125
126 fn make_span(&self, start: usize, start_line: usize, start_col: usize) -> Span {
127 Span::new(start, self.pos, start_line, start_col)
128 }
129
130 fn scan_token(&mut self) {
131 let start = self.pos;
132 let start_line = self.line;
133 let start_col = self.column;
134
135 let c = match self.advance() {
136 Some(c) => c,
137 None => return,
138 };
139
140 match c {
141 ' ' | '\t' | '\r' | '\n' => {
143 }
145
146 '(' => self.add_token(TokenKind::LParen, start, start_line, start_col),
148 ')' => self.add_token(TokenKind::RParen, start, start_line, start_col),
149 '[' => self.add_token(TokenKind::LBracket, start, start_line, start_col),
150 ']' => self.add_token(TokenKind::RBracket, start, start_line, start_col),
151 ',' => self.add_token(TokenKind::Comma, start, start_line, start_col),
152 ';' => self.add_token(TokenKind::Semicolon, start, start_line, start_col),
153 '+' => self.add_token(TokenKind::Plus, start, start_line, start_col),
154 '*' => self.add_token(TokenKind::Star, start, start_line, start_col),
155 '/' => {
156 if self.peek() == Some('/') || self.peek() == Some('*') {
157 self.scan_comment(start, start_line, start_col);
158 } else {
159 self.add_token(TokenKind::Slash, start, start_line, start_col);
160 }
161 }
162 '%' => self.add_token(TokenKind::Percent, start, start_line, start_col),
163 '&' => self.add_token(TokenKind::BitAnd, start, start_line, start_col),
164 '~' => self.add_token(TokenKind::BitNot, start, start_line, start_col),
165 '?' => {
166 self.placeholder_counter += 1;
168 let span = self.make_span(start, start_line, start_col);
169 self.tokens.push(Token::new(
170 TokenKind::Placeholder(self.placeholder_counter),
171 span,
172 "?",
173 ));
174 }
175 '@' => self.add_token(TokenKind::At, start, start_line, start_col),
176
177 '-' => {
179 if self.peek() == Some('-') {
180 self.scan_line_comment(start, start_line, start_col);
182 } else if self.peek() == Some('>') {
183 self.advance();
184 if self.peek() == Some('>') {
185 self.advance();
186 self.add_token(TokenKind::DoubleArrow, start, start_line, start_col);
187 } else {
188 self.add_token(TokenKind::Arrow, start, start_line, start_col);
189 }
190 } else {
191 self.add_token(TokenKind::Minus, start, start_line, start_col);
192 }
193 }
194
195 '=' => self.add_token(TokenKind::Eq, start, start_line, start_col),
196
197 '!' => {
198 if self.peek() == Some('=') {
199 self.advance();
200 self.add_token(TokenKind::Ne, start, start_line, start_col);
201 } else {
202 self.add_error("Unexpected character '!'", start, start_line, start_col);
203 }
204 }
205
206 '<' => {
207 if self.peek() == Some('=') {
208 self.advance();
209 self.add_token(TokenKind::Le, start, start_line, start_col);
210 } else if self.peek() == Some('>') {
211 self.advance();
212 self.add_token(TokenKind::Ne, start, start_line, start_col);
213 } else if self.peek() == Some('<') {
214 self.advance();
215 self.add_token(TokenKind::LeftShift, start, start_line, start_col);
216 } else {
217 self.add_token(TokenKind::Lt, start, start_line, start_col);
218 }
219 }
220
221 '>' => {
222 if self.peek() == Some('=') {
223 self.advance();
224 self.add_token(TokenKind::Ge, start, start_line, start_col);
225 } else if self.peek() == Some('>') {
226 self.advance();
227 self.add_token(TokenKind::RightShift, start, start_line, start_col);
228 } else {
229 self.add_token(TokenKind::Gt, start, start_line, start_col);
230 }
231 }
232
233 '|' => {
234 if self.peek() == Some('|') {
235 self.advance();
236 self.add_token(TokenKind::Concat, start, start_line, start_col);
237 } else {
238 self.add_token(TokenKind::BitOr, start, start_line, start_col);
239 }
240 }
241
242 ':' => {
243 if self.peek() == Some(':') {
244 self.advance();
245 self.add_token(TokenKind::DoubleColon, start, start_line, start_col);
246 } else {
247 self.add_token(TokenKind::Colon, start, start_line, start_col);
248 }
249 }
250
251 '.' => {
252 if self.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
253 self.scan_number(start, start_line, start_col, true);
254 } else {
255 self.add_token(TokenKind::Dot, start, start_line, start_col);
256 }
257 }
258
259 '\'' => self.scan_string(start, start_line, start_col, '\''),
261 '"' => self.scan_quoted_identifier(start, start_line, start_col, '"'),
262 '`' => self.scan_quoted_identifier(start, start_line, start_col, '`'),
263
264 'X' | 'x' if self.peek() == Some('\'') => {
266 self.advance();
267 self.scan_blob(start, start_line, start_col);
268 }
269
270 '0'..='9' => self.scan_number(start, start_line, start_col, false),
272
273 'a'..='z' | 'A'..='Z' | '_' => self.scan_identifier(start, start_line, start_col),
275
276 '$' => self.scan_placeholder(start, start_line, start_col),
278
279 _ => {
280 self.add_error(
281 format!("Unexpected character '{}'", c),
282 start,
283 start_line,
284 start_col,
285 );
286 }
287 }
288 }
289
290 fn scan_string(&mut self, start: usize, start_line: usize, start_col: usize, quote: char) {
291 let mut value = String::new();
292
293 while let Some(c) = self.peek() {
294 if c == quote {
295 self.advance();
296 if self.peek() == Some(quote) {
298 self.advance();
299 value.push(quote);
300 } else {
301 let span = self.make_span(start, start_line, start_col);
303 self.tokens
304 .push(Token::new(TokenKind::String(value), span, ""));
305 return;
306 }
307 } else if c == '\\' {
308 self.advance();
309 if let Some(escaped) = self.advance() {
311 match escaped {
312 'n' => value.push('\n'),
313 'r' => value.push('\r'),
314 't' => value.push('\t'),
315 '\\' => value.push('\\'),
316 '\'' => value.push('\''),
317 '"' => value.push('"'),
318 '0' => value.push('\0'),
319 _ => {
320 value.push('\\');
321 value.push(escaped);
322 }
323 }
324 }
325 } else {
326 self.advance();
327 value.push(c);
328 }
329 }
330
331 self.add_error("Unterminated string literal", start, start_line, start_col);
332 }
333
334 fn scan_quoted_identifier(
335 &mut self,
336 start: usize,
337 start_line: usize,
338 start_col: usize,
339 quote: char,
340 ) {
341 let mut value = String::new();
342
343 while let Some(c) = self.peek() {
344 if c == quote {
345 self.advance();
346 if self.peek() == Some(quote) {
348 self.advance();
349 value.push(quote);
350 } else {
351 let span = self.make_span(start, start_line, start_col);
352 self.tokens
353 .push(Token::new(TokenKind::QuotedIdentifier(value), span, ""));
354 return;
355 }
356 } else {
357 self.advance();
358 value.push(c);
359 }
360 }
361
362 self.add_error(
363 "Unterminated quoted identifier",
364 start,
365 start_line,
366 start_col,
367 );
368 }
369
370 fn scan_number(
371 &mut self,
372 start: usize,
373 start_line: usize,
374 start_col: usize,
375 started_with_dot: bool,
376 ) {
377 let num_start = start;
378 let mut has_dot = started_with_dot;
379 let mut has_exp = false;
380
381 while let Some(c) = self.peek() {
383 if c.is_ascii_digit() {
384 self.advance();
385 } else if c == '.' && !has_dot && !has_exp {
386 if self.peek_next() == Some('.') {
388 break;
389 }
390 has_dot = true;
391 self.advance();
392 } else if (c == 'e' || c == 'E') && !has_exp {
393 has_exp = true;
394 self.advance();
395 if self.peek() == Some('+') || self.peek() == Some('-') {
397 self.advance();
398 }
399 } else {
400 break;
401 }
402 }
403
404 let literal = &self.input[num_start..self.pos];
405 let span = self.make_span(start, start_line, start_col);
406
407 if has_dot || has_exp {
408 match literal.parse::<f64>() {
409 Ok(n) => self
410 .tokens
411 .push(Token::new(TokenKind::Float(n), span, literal)),
412 Err(_) => self.add_error("Invalid float literal", start, start_line, start_col),
413 }
414 } else {
415 match literal.parse::<i64>() {
416 Ok(n) => self
417 .tokens
418 .push(Token::new(TokenKind::Integer(n), span, literal)),
419 Err(_) => self.add_error("Invalid integer literal", start, start_line, start_col),
420 }
421 }
422 }
423
424 fn scan_identifier(&mut self, start: usize, start_line: usize, start_col: usize) {
425 while let Some(c) = self.peek() {
426 if c.is_ascii_alphanumeric() || c == '_' {
427 self.advance();
428 } else {
429 break;
430 }
431 }
432
433 let literal = &self.input[start..self.pos];
434 let span = self.make_span(start, start_line, start_col);
435
436 let kind = TokenKind::from_keyword(literal)
438 .unwrap_or_else(|| TokenKind::Identifier(literal.to_string()));
439
440 self.tokens.push(Token::new(kind, span, literal));
441 }
442
443 fn scan_placeholder(&mut self, start: usize, start_line: usize, start_col: usize) {
444 let mut num = String::new();
445
446 while let Some(c) = self.peek() {
447 if c.is_ascii_digit() {
448 self.advance();
449 num.push(c);
450 } else {
451 break;
452 }
453 }
454
455 let span = self.make_span(start, start_line, start_col);
456
457 if num.is_empty() {
458 self.add_error("Expected number after $", start, start_line, start_col);
459 } else if let Ok(n) = num.parse::<u32>() {
460 self.tokens.push(Token::new(
461 TokenKind::Placeholder(n),
462 span,
463 &self.input[start..self.pos],
464 ));
465 } else {
466 self.add_error("Invalid placeholder number", start, start_line, start_col);
467 }
468 }
469
470 fn scan_comment(&mut self, start: usize, start_line: usize, start_col: usize) {
471 self.advance(); if self.peek() == Some('*') || self.input[start..self.pos].ends_with('*') {
474 let mut depth = 1;
476
477 while depth > 0 && !self.is_at_end() {
478 let c = self.peek();
479 let next = self.peek_next();
480
481 if c == Some('*') && next == Some('/') {
482 self.advance();
483 self.advance();
484 depth -= 1;
485 } else if c == Some('/') && next == Some('*') {
486 self.advance();
487 self.advance();
488 depth += 1;
489 } else {
490 self.advance();
491 }
492 }
493
494 if depth > 0 {
495 self.add_error("Unterminated block comment", start, start_line, start_col);
496 }
497 } else {
498 while let Some(c) = self.peek() {
500 if c == '\n' {
501 break;
502 }
503 self.advance();
504 }
505 }
506 }
508
509 fn scan_line_comment(&mut self, _start: usize, _start_line: usize, _start_col: usize) {
510 self.advance(); while let Some(c) = self.peek() {
513 if c == '\n' {
514 break;
515 }
516 self.advance();
517 }
518 }
520
521 fn scan_blob(&mut self, start: usize, start_line: usize, start_col: usize) {
522 let mut hex = String::new();
523
524 while let Some(c) = self.peek() {
525 if c == '\'' {
526 self.advance();
527 break;
528 } else if c.is_ascii_hexdigit() {
529 self.advance();
530 hex.push(c);
531 } else if c.is_whitespace() {
532 self.advance(); } else {
534 self.add_error(
535 "Invalid hex digit in blob literal",
536 start,
537 start_line,
538 start_col,
539 );
540 return;
541 }
542 }
543
544 if !hex.len().is_multiple_of(2) {
545 self.add_error(
546 "Blob literal must have even number of hex digits",
547 start,
548 start_line,
549 start_col,
550 );
551 return;
552 }
553
554 let bytes: Result<Vec<u8>, _> = (0..hex.len())
555 .step_by(2)
556 .map(|i| u8::from_str_radix(&hex[i..i + 2], 16))
557 .collect();
558
559 match bytes {
560 Ok(data) => {
561 let span = self.make_span(start, start_line, start_col);
562 self.tokens
563 .push(Token::new(TokenKind::Blob(data), span, ""));
564 }
565 Err(_) => {
566 self.add_error("Invalid blob literal", start, start_line, start_col);
567 }
568 }
569 }
570
571 fn add_token(&mut self, kind: TokenKind, start: usize, start_line: usize, start_col: usize) {
572 let span = self.make_span(start, start_line, start_col);
573 let literal = &self.input[start..self.pos];
574 self.tokens.push(Token::new(kind, span, literal));
575 }
576
577 fn add_error(
578 &mut self,
579 message: impl Into<String>,
580 start: usize,
581 start_line: usize,
582 start_col: usize,
583 ) {
584 let span = self.make_span(start, start_line, start_col);
585 self.errors.push(LexError::new(message, span));
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `src`, panicking if the lexer reports any errors.
    fn lex(src: &str) -> Vec<Token> {
        Lexer::new(src).tokenize().unwrap()
    }

    #[test]
    fn test_simple_select() {
        let tokens = lex("SELECT * FROM users");
        // Four real tokens plus the trailing Eof.
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
        assert_eq!(tokens[2].kind, TokenKind::From);
        assert!(matches!(tokens[3].kind, TokenKind::Identifier(_)));
    }

    #[test]
    fn test_string_literal() {
        // A doubled quote collapses into a single literal quote.
        let tokens = lex("SELECT 'hello''world'");
        assert!(matches!(&tokens[1].kind, TokenKind::String(s) if s == "hello'world"));
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn test_numbers() {
        let tokens = lex("42 3.14 1e10 .5");
        assert!(matches!(tokens[0].kind, TokenKind::Integer(42)));
        assert!(matches!(tokens[1].kind, TokenKind::Float(f) if (f - 3.14).abs() < 0.001));
        assert!(matches!(tokens[2].kind, TokenKind::Float(_)));
        assert!(matches!(tokens[3].kind, TokenKind::Float(f) if (f - 0.5).abs() < 0.001));
    }

    #[test]
    fn test_operators() {
        let tokens = lex("= != <> < <= > >= || ->");
        let expected = [
            TokenKind::Eq,
            TokenKind::Ne,
            TokenKind::Ne,
            TokenKind::Lt,
            TokenKind::Le,
            TokenKind::Gt,
            TokenKind::Ge,
            TokenKind::Concat,
            TokenKind::Arrow,
        ];
        for (i, kind) in expected.iter().enumerate() {
            assert_eq!(&tokens[i].kind, kind);
        }
    }

    #[test]
    fn test_keywords() {
        let tokens = lex("SELECT INSERT UPDATE DELETE FROM WHERE");
        let expected = [
            TokenKind::Select,
            TokenKind::Insert,
            TokenKind::Update,
            TokenKind::Delete,
            TokenKind::From,
            TokenKind::Where,
        ];
        for (i, kind) in expected.iter().enumerate() {
            assert_eq!(&tokens[i].kind, kind);
        }
    }

    #[test]
    fn test_placeholder() {
        let tokens = lex("$1 $2 $10");
        assert!(matches!(tokens[0].kind, TokenKind::Placeholder(1)));
        assert!(matches!(tokens[1].kind, TokenKind::Placeholder(2)));
        assert!(matches!(tokens[2].kind, TokenKind::Placeholder(10)));
    }

    #[test]
    fn test_line_comment() {
        // The `--` comment is skipped entirely; only the surrounding
        // tokens (plus Eof) remain.
        let tokens = lex("SELECT -- comment\n* FROM users");
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
    }

    #[test]
    fn test_blob_literal() {
        let tokens = lex("X'48454C4C4F'");
        assert!(matches!(&tokens[0].kind, TokenKind::Blob(b) if b == b"HELLO"));
    }
}