1use crate::theme::Theme;
24
/// Lexical category assigned to each segment produced by the SQL tokenizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlToken {
    /// Word found in the highlighter's keyword table (statements, types, functions).
    Keyword,
    /// Single-quoted string literal.
    String,
    /// Numeric literal, including decimal and exponent forms.
    Number,
    /// `--` line comment or `/* ... */` block comment.
    Comment,
    /// Symbolic operator such as `=` or `||`, or a word operator such as `AND`.
    Operator,
    /// Bare or double-quoted identifier; also the fallback for unrecognized characters.
    Identifier,
    /// One of `(`, `)`, `,`, `;` or `.`.
    Punctuation,
    /// Run of consecutive whitespace characters.
    Whitespace,
    /// Bind parameter written as `$1`, `?` or `:name`.
    Parameter,
}
47
/// A contiguous piece of the input SQL paired with its token category.
#[derive(Debug, Clone)]
pub struct SqlSegment {
    /// Source text covered by this segment, copied verbatim from the input.
    pub text: String,
    /// Category the tokenizer assigned to `text`.
    pub token: SqlToken,
}
56
/// Tokenizes SQL text and renders it using the color codes of a [`Theme`].
#[derive(Debug, Clone)]
pub struct SqlHighlighter {
    /// Palette consulted when mapping a token category to a color code.
    theme: Theme,
}
65
66impl SqlHighlighter {
    /// Upper-case words that `is_keyword` recognizes (compared case-insensitively).
    ///
    /// Some entries (`NOT`, `BETWEEN`, `LIKE`, `ILIKE`, `IN`) are also word
    /// operators; `tokenize` checks the operator list first, so those are
    /// emitted as operators rather than keywords.
    const KEYWORDS: &'static [&'static str] = &[
        // Queries and data manipulation.
        "SELECT",
        "INSERT",
        "UPDATE",
        "DELETE",
        "FROM",
        "WHERE",
        "SET",
        "VALUES",
        "INTO",
        "JOIN",
        "LEFT",
        "RIGHT",
        "INNER",
        "OUTER",
        "FULL",
        "CROSS",
        "ON",
        "USING",
        "AS",
        "DISTINCT",
        "ALL",
        "ORDER",
        "BY",
        "ASC",
        "DESC",
        "NULLS",
        "FIRST",
        "LAST",
        "LIMIT",
        "OFFSET",
        "FETCH",
        "NEXT",
        "ROWS",
        "ONLY",
        "GROUP",
        "HAVING",
        "UNION",
        "INTERSECT",
        "EXCEPT",
        "CASE",
        "WHEN",
        "THEN",
        "ELSE",
        "END",
        "BETWEEN",
        "IN",
        "LIKE",
        "ILIKE",
        "SIMILAR",
        "TO",
        "EXISTS",
        "ANY",
        "SOME",
        "RETURNING",
        "WITH",
        "RECURSIVE",
        // Schema definition and constraints.
        "CREATE",
        "ALTER",
        "DROP",
        "TRUNCATE",
        "TABLE",
        "INDEX",
        "VIEW",
        "SCHEMA",
        "DATABASE",
        "CONSTRAINT",
        "PRIMARY",
        "KEY",
        "FOREIGN",
        "REFERENCES",
        "UNIQUE",
        "CHECK",
        "DEFAULT",
        "NOT",
        "NULL",
        "AUTO_INCREMENT",
        "AUTOINCREMENT",
        "SERIAL",
        "IF",
        "CASCADE",
        "RESTRICT",
        // Transaction control.
        "BEGIN",
        "COMMIT",
        "ROLLBACK",
        "SAVEPOINT",
        "TRANSACTION",
        "START",
        "RELEASE",
        // Data types.
        "INTEGER",
        "INT",
        "BIGINT",
        "SMALLINT",
        "TINYINT",
        "REAL",
        "FLOAT",
        "DOUBLE",
        "PRECISION",
        "DECIMAL",
        "NUMERIC",
        "VARCHAR",
        "CHAR",
        "TEXT",
        "BLOB",
        "BYTEA",
        "BOOLEAN",
        "BOOL",
        "DATE",
        "TIME",
        "TIMESTAMP",
        "INTERVAL",
        "UUID",
        "JSON",
        "JSONB",
        "ARRAY",
        // Built-in functions and special values.
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "COALESCE",
        "NULLIF",
        "CAST",
        "EXTRACT",
        "NOW",
        "CURRENT_DATE",
        "CURRENT_TIME",
        "CURRENT_TIMESTAMP",
        "LOWER",
        "UPPER",
        "TRIM",
        "SUBSTRING",
        "LENGTH",
        "CONCAT",
        "REPLACE",
    ];
209
210 #[must_use]
212 pub fn new() -> Self {
213 Self {
214 theme: Theme::default(),
215 }
216 }
217
    /// Creates a highlighter that colors tokens using the given `theme`.
    #[must_use]
    pub fn with_theme(theme: Theme) -> Self {
        Self { theme }
    }
223
    /// Builder-style setter: consumes the highlighter, replaces its theme with
    /// `theme`, and returns the updated highlighter.
    #[must_use]
    pub fn theme(mut self, theme: Theme) -> Self {
        self.theme = theme;
        self
    }
230
231 fn is_keyword(word: &str) -> bool {
233 let upper = word.to_uppercase();
234 Self::KEYWORDS.contains(&upper.as_str())
235 }
236
237 fn is_operator_keyword(word: &str) -> bool {
239 let upper = word.to_uppercase();
240 matches!(
241 upper.as_str(),
242 "AND" | "OR" | "NOT" | "IS" | "BETWEEN" | "LIKE" | "ILIKE" | "IN"
243 )
244 }
245
246 #[must_use]
248 pub fn tokenize(&self, sql: &str) -> Vec<SqlSegment> {
249 let mut segments = Vec::new();
250 let chars: Vec<char> = sql.chars().collect();
251 let mut i = 0;
252
253 while i < chars.len() {
254 let c = chars[i];
255
256 if c.is_whitespace() {
258 let start = i;
259 while i < chars.len() && chars[i].is_whitespace() {
260 i += 1;
261 }
262 segments.push(SqlSegment {
263 text: chars[start..i].iter().collect(),
264 token: SqlToken::Whitespace,
265 });
266 continue;
267 }
268
269 if c == '-' && i + 1 < chars.len() && chars[i + 1] == '-' {
271 let start = i;
272 while i < chars.len() && chars[i] != '\n' {
273 i += 1;
274 }
275 segments.push(SqlSegment {
276 text: chars[start..i].iter().collect(),
277 token: SqlToken::Comment,
278 });
279 continue;
280 }
281
282 if c == '/' && i + 1 < chars.len() && chars[i + 1] == '*' {
284 let start = i;
285 i += 2;
286 while i + 1 < chars.len() && !(chars[i] == '*' && chars[i + 1] == '/') {
287 i += 1;
288 }
289 if i + 1 < chars.len() {
290 i += 2; }
292 segments.push(SqlSegment {
293 text: chars[start..i].iter().collect(),
294 token: SqlToken::Comment,
295 });
296 continue;
297 }
298
299 if c == '\'' {
301 let start = i;
302 i += 1;
303 while i < chars.len() {
304 if chars[i] == '\'' {
305 if i + 1 < chars.len() && chars[i + 1] == '\'' {
306 i += 2; } else {
308 i += 1;
309 break;
310 }
311 } else {
312 i += 1;
313 }
314 }
315 segments.push(SqlSegment {
316 text: chars[start..i].iter().collect(),
317 token: SqlToken::String,
318 });
319 continue;
320 }
321
322 if c == '"' {
324 let start = i;
325 i += 1;
326 while i < chars.len() && chars[i] != '"' {
327 i += 1;
328 }
329 if i < chars.len() {
330 i += 1;
331 }
332 segments.push(SqlSegment {
333 text: chars[start..i].iter().collect(),
334 token: SqlToken::Identifier,
335 });
336 continue;
337 }
338
339 if c == '$' || c == '?' {
341 let start = i;
342 i += 1;
343 while i < chars.len() && chars[i].is_ascii_digit() {
344 i += 1;
345 }
346 segments.push(SqlSegment {
347 text: chars[start..i].iter().collect(),
348 token: SqlToken::Parameter,
349 });
350 continue;
351 }
352
353 if c == ':'
355 && i + 1 < chars.len()
356 && (chars[i + 1].is_alphabetic() || chars[i + 1] == '_')
357 {
358 let start = i;
359 i += 1;
360 while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
361 i += 1;
362 }
363 segments.push(SqlSegment {
364 text: chars[start..i].iter().collect(),
365 token: SqlToken::Parameter,
366 });
367 continue;
368 }
369
370 if c.is_ascii_digit()
372 || (c == '.' && i + 1 < chars.len() && chars[i + 1].is_ascii_digit())
373 {
374 let start = i;
375 let mut has_dot = c == '.';
376 i += 1;
377 while i < chars.len() {
378 if chars[i].is_ascii_digit() {
379 i += 1;
380 } else if chars[i] == '.' && !has_dot {
381 has_dot = true;
382 i += 1;
383 } else if chars[i] == 'e' || chars[i] == 'E' {
384 i += 1;
385 if i < chars.len() && (chars[i] == '+' || chars[i] == '-') {
386 i += 1;
387 }
388 } else {
389 break;
390 }
391 }
392 segments.push(SqlSegment {
393 text: chars[start..i].iter().collect(),
394 token: SqlToken::Number,
395 });
396 continue;
397 }
398
399 if c.is_alphabetic() || c == '_' {
401 let start = i;
402 while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
403 i += 1;
404 }
405 let word: String = chars[start..i].iter().collect();
406 let token = if Self::is_operator_keyword(&word) {
407 SqlToken::Operator
408 } else if Self::is_keyword(&word) {
409 SqlToken::Keyword
410 } else {
411 SqlToken::Identifier
412 };
413 segments.push(SqlSegment { text: word, token });
414 continue;
415 }
416
417 if matches!(c, '=' | '<' | '>' | '!' | '+' | '-' | '*' | '/' | '%' | '|') {
419 let start = i;
420 i += 1;
421 if i < chars.len() {
423 let next = chars[i];
424 let is_two_char_op =
425 matches!((c, next), ('<', '>' | '=') | ('>' | '!', '=') | ('|', '|'));
426 if is_two_char_op {
427 i += 1;
428 }
429 }
430 segments.push(SqlSegment {
431 text: chars[start..i].iter().collect(),
432 token: SqlToken::Operator,
433 });
434 continue;
435 }
436
437 if matches!(c, '(' | ')' | ',' | ';' | '.') {
439 segments.push(SqlSegment {
440 text: c.to_string(),
441 token: SqlToken::Punctuation,
442 });
443 i += 1;
444 continue;
445 }
446
447 segments.push(SqlSegment {
449 text: c.to_string(),
450 token: SqlToken::Identifier,
451 });
452 i += 1;
453 }
454
455 segments
456 }
457
458 fn color_for_token(&self, token: SqlToken) -> String {
460 match token {
461 SqlToken::Keyword => self.theme.sql_keyword.color_code(),
462 SqlToken::String => self.theme.sql_string.color_code(),
463 SqlToken::Number => self.theme.sql_number.color_code(),
464 SqlToken::Comment => self.theme.sql_comment.color_code(),
465 SqlToken::Operator => self.theme.sql_operator.color_code(),
466 SqlToken::Identifier => self.theme.sql_identifier.color_code(),
467 SqlToken::Parameter => self.theme.info.color_code(),
468 SqlToken::Punctuation | SqlToken::Whitespace => String::new(),
469 }
470 }
471
472 #[must_use]
474 pub fn highlight(&self, sql: &str) -> String {
475 let segments = self.tokenize(sql);
476 let reset = "\x1b[0m";
477
478 segments
479 .iter()
480 .map(|seg| {
481 let color = self.color_for_token(seg.token);
482 if color.is_empty() {
483 seg.text.clone()
484 } else {
485 format!("{}{}{}", color, seg.text, reset)
486 }
487 })
488 .collect()
489 }
490
491 #[must_use]
493 pub fn plain(&self, sql: &str) -> String {
494 sql.to_string()
495 }
496
497 #[must_use]
499 pub fn format(&self, sql: &str) -> String {
500 let segments = self.tokenize(sql);
501 let mut result = String::new();
502 let mut indent = 0;
503 let indent_str = " ";
504 let mut newline_before = false;
505
506 for seg in segments {
507 let upper = seg.text.to_uppercase();
508
509 if matches!(
511 upper.as_str(),
512 "SELECT"
513 | "FROM"
514 | "WHERE"
515 | "ORDER"
516 | "GROUP"
517 | "HAVING"
518 | "LIMIT"
519 | "OFFSET"
520 | "SET"
521 | "VALUES"
522 | "RETURNING"
523 | "UNION"
524 | "INTERSECT"
525 | "EXCEPT"
526 ) {
527 if !result.is_empty() && !result.ends_with('\n') {
528 result.push('\n');
529 }
530 result.push_str(&indent_str.repeat(indent));
531 newline_before = false;
532 }
533
534 if matches!(upper.as_str(), "(" | "CASE") {
536 indent += 1;
537 }
538
539 if matches!(upper.as_str(), ")" | "END") {
541 indent = indent.saturating_sub(1);
542 }
543
544 if matches!(upper.as_str(), "AND" | "OR")
546 && !newline_before
547 && !result.ends_with('\n')
548 && !result.ends_with(' ')
549 {
550 result.push('\n');
551 result.push_str(&indent_str.repeat(indent + 1));
552 }
553
554 if matches!(
556 upper.as_str(),
557 "JOIN" | "LEFT" | "RIGHT" | "INNER" | "OUTER" | "CROSS" | "FULL"
558 ) {
559 if !result.ends_with('\n') && !result.ends_with(' ') && upper != "JOIN" {
560 } else if upper == "JOIN" && !result.ends_with(' ') {
562 result.push(' ');
563 }
564 }
565
566 if seg.token == SqlToken::Whitespace {
568 if !result.ends_with(' ') && !result.ends_with('\n') {
570 result.push(' ');
571 }
572 } else {
573 result.push_str(&seg.text);
574 }
575
576 newline_before = seg.token == SqlToken::Whitespace && seg.text.contains('\n');
577 }
578
579 result.trim().to_string()
580 }
581}
582
/// The default highlighter is the one returned by [`SqlHighlighter::new`].
impl Default for SqlHighlighter {
    fn default() -> Self {
        Self::new()
    }
}
588
// Unit tests for tokenization, highlighting and formatting.
#[cfg(test)]
mod tests {
    use super::*;

    // A default highlighter keeps the keyword text in its output.
    #[test]
    fn test_highlighter_new() {
        let h = SqlHighlighter::new();
        assert!(h.highlight("SELECT 1").contains("SELECT"));
    }

    // A simple query yields both keyword and identifier tokens.
    #[test]
    fn test_tokenize_select() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users");

        let tokens: Vec<SqlToken> = segments.iter().map(|s| s.token).collect();
        assert!(tokens.contains(&SqlToken::Keyword));
        assert!(tokens.contains(&SqlToken::Identifier));
    }

    // A single-quoted literal is one String segment including its quotes.
    #[test]
    fn test_tokenize_string() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 'hello'");

        let has_string = segments
            .iter()
            .any(|s| s.token == SqlToken::String && s.text == "'hello'");
        assert!(has_string);
    }

    // Integers and decimals are both recognized as Number segments.
    #[test]
    fn test_tokenize_number() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 42, 3.14");

        let numbers: Vec<&str> = segments
            .iter()
            .filter(|s| s.token == SqlToken::Number)
            .map(|s| s.text.as_str())
            .collect();
        assert!(numbers.contains(&"42"));
        assert!(numbers.contains(&"3.14"));
    }

    // `--` starts a line comment.
    #[test]
    fn test_tokenize_comment_single() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 1 -- comment");

        let has_comment = segments.iter().any(|s| s.token == SqlToken::Comment);
        assert!(has_comment);
    }

    // `/* ... */` is a block comment.
    #[test]
    fn test_tokenize_comment_multi() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT /* comment */ 1");

        let has_comment = segments.iter().any(|s| s.token == SqlToken::Comment);
        assert!(has_comment);
    }

    // `$1` is a positional parameter.
    #[test]
    fn test_tokenize_parameter_positional() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE id = $1");

        let has_param = segments
            .iter()
            .any(|s| s.token == SqlToken::Parameter && s.text == "$1");
        assert!(has_param);
    }

    // A bare `?` is a parameter.
    #[test]
    fn test_tokenize_parameter_question() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE id = ?");

        let has_param = segments
            .iter()
            .any(|s| s.token == SqlToken::Parameter && s.text == "?");
        assert!(has_param);
    }

    // `:name` is a named parameter.
    #[test]
    fn test_tokenize_parameter_named() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE id = :user_id");

        let has_param = segments
            .iter()
            .any(|s| s.token == SqlToken::Parameter && s.text == ":user_id");
        assert!(has_param);
    }

    // Two-character symbols and word operators are both Operator tokens.
    #[test]
    fn test_tokenize_operators() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE age >= 18 AND active = true");

        let has_ge = segments
            .iter()
            .any(|s| s.token == SqlToken::Operator && s.text == ">=");
        let has_and = segments
            .iter()
            .any(|s| s.token == SqlToken::Operator && s.text.to_uppercase() == "AND");
        assert!(has_ge);
        assert!(has_and);
    }

    // A double-quoted name is a single Identifier, even with a `-` inside.
    #[test]
    fn test_tokenize_quoted_identifier() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT \"user-name\" FROM users");

        let has_quoted = segments
            .iter()
            .any(|s| s.token == SqlToken::Identifier && s.text == "\"user-name\"");
        assert!(has_quoted);
    }

    #[test]
    fn test_highlight_produces_ansi() {
        let h = SqlHighlighter::new();
        let highlighted = h.highlight("SELECT 1");

        // Escape sequences are present...
        assert!(highlighted.contains('\x1b'));
        // ...and the original token text survives between them.
        assert!(highlighted.contains("SELECT"));
        assert!(highlighted.contains('1'));
    }

    // `plain` returns its input untouched.
    #[test]
    fn test_plain_no_change() {
        let h = SqlHighlighter::new();
        let sql = "SELECT * FROM users";
        assert_eq!(h.plain(sql), sql);
    }

    #[test]
    fn test_format_basic() {
        let h = SqlHighlighter::new();
        let sql = "SELECT * FROM users WHERE id = 1";
        let formatted = h.format(sql);

        // Formatting must keep every clause keyword.
        assert!(formatted.contains("SELECT"));
        assert!(formatted.contains("FROM"));
        assert!(formatted.contains("WHERE"));
    }

    // Keyword lookup ignores letter case.
    #[test]
    fn test_is_keyword() {
        assert!(SqlHighlighter::is_keyword("SELECT"));
        assert!(SqlHighlighter::is_keyword("select"));
        assert!(SqlHighlighter::is_keyword("Select"));
        assert!(!SqlHighlighter::is_keyword("users"));
    }

    // Operator-word lookup ignores letter case and rejects plain keywords.
    #[test]
    fn test_is_operator_keyword() {
        assert!(SqlHighlighter::is_operator_keyword("AND"));
        assert!(SqlHighlighter::is_operator_keyword("or"));
        assert!(!SqlHighlighter::is_operator_keyword("SELECT"));
    }

    // A doubled quote inside a string is an escape, not a terminator.
    #[test]
    fn test_escaped_string() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 'it''s'");

        let string_seg = segments.iter().find(|s| s.token == SqlToken::String);
        assert!(string_seg.is_some());
        assert_eq!(string_seg.unwrap().text, "'it''s'");
    }

    // The exponent is part of the Number segment.
    #[test]
    fn test_scientific_notation() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 1.5e10");

        let has_num = segments
            .iter()
            .any(|s| s.token == SqlToken::Number && s.text.contains('e'));
        assert!(has_num);
    }

    // Constructing with an explicit theme still yields colored output.
    #[test]
    fn test_with_theme() {
        let h = SqlHighlighter::with_theme(Theme::light());
        let highlighted = h.highlight("SELECT 1");
        assert!(highlighted.contains('\x1b'));
    }

    // The builder-style `theme` setter still yields colored output.
    #[test]
    fn test_builder_theme() {
        let h = SqlHighlighter::new().theme(Theme::dark());
        let highlighted = h.highlight("SELECT 1");
        assert!(highlighted.contains('\x1b'));
    }
}