Skip to main content

sqlmodel_console/renderables/
sql_syntax.rs

1//! SQL syntax highlighting for query display.
2//!
3//! Provides syntax highlighting for SQL queries with theme-based coloring.
4//!
5//! # Example
6//!
7//! ```rust
8//! use sqlmodel_console::renderables::SqlHighlighter;
9//! use sqlmodel_console::Theme;
10//!
11//! let highlighter = SqlHighlighter::new();
12//! let sql = "SELECT * FROM users WHERE id = 1";
13//!
14//! // Get highlighted version
15//! let highlighted = highlighter.highlight(sql);
16//! println!("{}", highlighted);
17//!
18//! // Or plain version
19//! let plain = highlighter.plain(sql);
20//! println!("{}", plain);
21//! ```
22
23use crate::theme::Theme;
24
/// Classification of a piece of SQL text, used to choose its highlight colour.
///
/// The type is `Copy + Eq + Hash`, so it can be used directly as a map or set key
/// (e.g. a per-token style table).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlToken {
    /// SQL keyword (SELECT, FROM, WHERE, etc.)
    Keyword,
    /// String literal ('value')
    String,
    /// Numeric literal (42, 3.14)
    Number,
    /// SQL comment (-- comment or /* comment */)
    Comment,
    /// SQL operator (=, <, >, AND, OR, etc.)
    Operator,
    /// Identifier (table name, column name, quoted identifier)
    Identifier,
    /// Punctuation (, ; ( ))
    Punctuation,
    /// Whitespace
    Whitespace,
    /// Parameter placeholder ($1, ?, :name)
    Parameter,
}
47
48/// A segment of SQL text with its token type.
49#[derive(Debug, Clone)]
50pub struct SqlSegment {
51    /// The text content
52    pub text: String,
53    /// The token type
54    pub token: SqlToken,
55}
56
57/// SQL syntax highlighter.
58///
59/// Tokenizes SQL queries and produces highlighted output using theme colors.
60#[derive(Debug, Clone)]
61pub struct SqlHighlighter {
62    /// Theme for coloring
63    theme: Theme,
64}
65
66impl SqlHighlighter {
67    /// SQL keywords to highlight.
68    const KEYWORDS: &'static [&'static str] = &[
69        // DML
70        "SELECT",
71        "INSERT",
72        "UPDATE",
73        "DELETE",
74        "FROM",
75        "WHERE",
76        "SET",
77        "VALUES",
78        "INTO",
79        "JOIN",
80        "LEFT",
81        "RIGHT",
82        "INNER",
83        "OUTER",
84        "FULL",
85        "CROSS",
86        "ON",
87        "USING",
88        "AS",
89        "DISTINCT",
90        "ALL",
91        "ORDER",
92        "BY",
93        "ASC",
94        "DESC",
95        "NULLS",
96        "FIRST",
97        "LAST",
98        "LIMIT",
99        "OFFSET",
100        "FETCH",
101        "NEXT",
102        "ROWS",
103        "ONLY",
104        "GROUP",
105        "HAVING",
106        "UNION",
107        "INTERSECT",
108        "EXCEPT",
109        "CASE",
110        "WHEN",
111        "THEN",
112        "ELSE",
113        "END",
114        "BETWEEN",
115        "IN",
116        "LIKE",
117        "ILIKE",
118        "SIMILAR",
119        "TO",
120        "EXISTS",
121        "ANY",
122        "SOME",
123        "RETURNING",
124        "WITH",
125        "RECURSIVE",
126        // DDL
127        "CREATE",
128        "ALTER",
129        "DROP",
130        "TRUNCATE",
131        "TABLE",
132        "INDEX",
133        "VIEW",
134        "SCHEMA",
135        "DATABASE",
136        "CONSTRAINT",
137        "PRIMARY",
138        "KEY",
139        "FOREIGN",
140        "REFERENCES",
141        "UNIQUE",
142        "CHECK",
143        "DEFAULT",
144        "NOT",
145        "NULL",
146        "AUTO_INCREMENT",
147        "AUTOINCREMENT",
148        "SERIAL",
149        "IF",
150        "CASCADE",
151        "RESTRICT",
152        // TCL
153        "BEGIN",
154        "COMMIT",
155        "ROLLBACK",
156        "SAVEPOINT",
157        "TRANSACTION",
158        "START",
159        "RELEASE",
160        // Types
161        "INTEGER",
162        "INT",
163        "BIGINT",
164        "SMALLINT",
165        "TINYINT",
166        "REAL",
167        "FLOAT",
168        "DOUBLE",
169        "PRECISION",
170        "DECIMAL",
171        "NUMERIC",
172        "VARCHAR",
173        "CHAR",
174        "TEXT",
175        "BLOB",
176        "BYTEA",
177        "BOOLEAN",
178        "BOOL",
179        "DATE",
180        "TIME",
181        "TIMESTAMP",
182        "INTERVAL",
183        "UUID",
184        "JSON",
185        "JSONB",
186        "ARRAY",
187        // Functions
188        "COUNT",
189        "SUM",
190        "AVG",
191        "MIN",
192        "MAX",
193        "COALESCE",
194        "NULLIF",
195        "CAST",
196        "EXTRACT",
197        "NOW",
198        "CURRENT_DATE",
199        "CURRENT_TIME",
200        "CURRENT_TIMESTAMP",
201        "LOWER",
202        "UPPER",
203        "TRIM",
204        "SUBSTRING",
205        "LENGTH",
206        "CONCAT",
207        "REPLACE",
208    ];
209
210    /// Create a new SQL highlighter with the default theme.
211    #[must_use]
212    pub fn new() -> Self {
213        Self {
214            theme: Theme::default(),
215        }
216    }
217
218    /// Create a new SQL highlighter with a specific theme.
219    #[must_use]
220    pub fn with_theme(theme: Theme) -> Self {
221        Self { theme }
222    }
223
224    /// Set the theme.
225    #[must_use]
226    pub fn theme(mut self, theme: Theme) -> Self {
227        self.theme = theme;
228        self
229    }
230
231    /// Check if a word is a SQL keyword.
232    fn is_keyword(word: &str) -> bool {
233        let upper = word.to_uppercase();
234        Self::KEYWORDS.contains(&upper.as_str())
235    }
236
237    /// Check if a word is a SQL operator keyword.
238    fn is_operator_keyword(word: &str) -> bool {
239        let upper = word.to_uppercase();
240        matches!(
241            upper.as_str(),
242            "AND" | "OR" | "NOT" | "IS" | "BETWEEN" | "LIKE" | "ILIKE" | "IN"
243        )
244    }
245
246    /// Tokenize SQL into segments.
247    #[must_use]
248    pub fn tokenize(&self, sql: &str) -> Vec<SqlSegment> {
249        let mut segments = Vec::new();
250        let chars: Vec<char> = sql.chars().collect();
251        let mut i = 0;
252
253        while i < chars.len() {
254            let c = chars[i];
255
256            // Whitespace
257            if c.is_whitespace() {
258                let start = i;
259                while i < chars.len() && chars[i].is_whitespace() {
260                    i += 1;
261                }
262                segments.push(SqlSegment {
263                    text: chars[start..i].iter().collect(),
264                    token: SqlToken::Whitespace,
265                });
266                continue;
267            }
268
269            // Single-line comment (-- ...)
270            if c == '-' && i + 1 < chars.len() && chars[i + 1] == '-' {
271                let start = i;
272                while i < chars.len() && chars[i] != '\n' {
273                    i += 1;
274                }
275                segments.push(SqlSegment {
276                    text: chars[start..i].iter().collect(),
277                    token: SqlToken::Comment,
278                });
279                continue;
280            }
281
282            // Multi-line comment (/* ... */)
283            if c == '/' && i + 1 < chars.len() && chars[i + 1] == '*' {
284                let start = i;
285                i += 2;
286                while i + 1 < chars.len() && !(chars[i] == '*' && chars[i + 1] == '/') {
287                    i += 1;
288                }
289                if i + 1 < chars.len() {
290                    i += 2; // Skip */
291                }
292                segments.push(SqlSegment {
293                    text: chars[start..i].iter().collect(),
294                    token: SqlToken::Comment,
295                });
296                continue;
297            }
298
299            // String literal ('...')
300            if c == '\'' {
301                let start = i;
302                i += 1;
303                while i < chars.len() {
304                    if chars[i] == '\'' {
305                        if i + 1 < chars.len() && chars[i + 1] == '\'' {
306                            i += 2; // Escaped quote
307                        } else {
308                            i += 1;
309                            break;
310                        }
311                    } else {
312                        i += 1;
313                    }
314                }
315                segments.push(SqlSegment {
316                    text: chars[start..i].iter().collect(),
317                    token: SqlToken::String,
318                });
319                continue;
320            }
321
322            // Double-quoted identifier ("...")
323            if c == '"' {
324                let start = i;
325                i += 1;
326                while i < chars.len() && chars[i] != '"' {
327                    i += 1;
328                }
329                if i < chars.len() {
330                    i += 1;
331                }
332                segments.push(SqlSegment {
333                    text: chars[start..i].iter().collect(),
334                    token: SqlToken::Identifier,
335                });
336                continue;
337            }
338
339            // Parameter placeholder ($1, $2, ?)
340            if c == '$' || c == '?' {
341                let start = i;
342                i += 1;
343                while i < chars.len() && chars[i].is_ascii_digit() {
344                    i += 1;
345                }
346                segments.push(SqlSegment {
347                    text: chars[start..i].iter().collect(),
348                    token: SqlToken::Parameter,
349                });
350                continue;
351            }
352
353            // Named parameter (:name)
354            if c == ':'
355                && i + 1 < chars.len()
356                && (chars[i + 1].is_alphabetic() || chars[i + 1] == '_')
357            {
358                let start = i;
359                i += 1;
360                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
361                    i += 1;
362                }
363                segments.push(SqlSegment {
364                    text: chars[start..i].iter().collect(),
365                    token: SqlToken::Parameter,
366                });
367                continue;
368            }
369
370            // Number
371            if c.is_ascii_digit()
372                || (c == '.' && i + 1 < chars.len() && chars[i + 1].is_ascii_digit())
373            {
374                let start = i;
375                let mut has_dot = c == '.';
376                i += 1;
377                while i < chars.len() {
378                    if chars[i].is_ascii_digit() {
379                        i += 1;
380                    } else if chars[i] == '.' && !has_dot {
381                        has_dot = true;
382                        i += 1;
383                    } else if chars[i] == 'e' || chars[i] == 'E' {
384                        i += 1;
385                        if i < chars.len() && (chars[i] == '+' || chars[i] == '-') {
386                            i += 1;
387                        }
388                    } else {
389                        break;
390                    }
391                }
392                segments.push(SqlSegment {
393                    text: chars[start..i].iter().collect(),
394                    token: SqlToken::Number,
395                });
396                continue;
397            }
398
399            // Identifier or keyword
400            if c.is_alphabetic() || c == '_' {
401                let start = i;
402                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
403                    i += 1;
404                }
405                let word: String = chars[start..i].iter().collect();
406                let token = if Self::is_operator_keyword(&word) {
407                    SqlToken::Operator
408                } else if Self::is_keyword(&word) {
409                    SqlToken::Keyword
410                } else {
411                    SqlToken::Identifier
412                };
413                segments.push(SqlSegment { text: word, token });
414                continue;
415            }
416
417            // Operators and punctuation
418            if matches!(c, '=' | '<' | '>' | '!' | '+' | '-' | '*' | '/' | '%' | '|') {
419                let start = i;
420                i += 1;
421                // Handle multi-char operators
422                if i < chars.len() {
423                    let next = chars[i];
424                    let is_two_char_op =
425                        matches!((c, next), ('<', '>' | '=') | ('>' | '!', '=') | ('|', '|'));
426                    if is_two_char_op {
427                        i += 1;
428                    }
429                }
430                segments.push(SqlSegment {
431                    text: chars[start..i].iter().collect(),
432                    token: SqlToken::Operator,
433                });
434                continue;
435            }
436
437            // Punctuation
438            if matches!(c, '(' | ')' | ',' | ';' | '.') {
439                segments.push(SqlSegment {
440                    text: c.to_string(),
441                    token: SqlToken::Punctuation,
442                });
443                i += 1;
444                continue;
445            }
446
447            // Unknown - treat as identifier
448            segments.push(SqlSegment {
449                text: c.to_string(),
450                token: SqlToken::Identifier,
451            });
452            i += 1;
453        }
454
455        segments
456    }
457
458    /// Get the ANSI color code for a token type.
459    fn color_for_token(&self, token: SqlToken) -> String {
460        match token {
461            SqlToken::Keyword => self.theme.sql_keyword.color_code(),
462            SqlToken::String => self.theme.sql_string.color_code(),
463            SqlToken::Number => self.theme.sql_number.color_code(),
464            SqlToken::Comment => self.theme.sql_comment.color_code(),
465            SqlToken::Operator => self.theme.sql_operator.color_code(),
466            SqlToken::Identifier => self.theme.sql_identifier.color_code(),
467            SqlToken::Parameter => self.theme.info.color_code(),
468            SqlToken::Punctuation | SqlToken::Whitespace => String::new(),
469        }
470    }
471
472    /// Highlight SQL with ANSI colors.
473    #[must_use]
474    pub fn highlight(&self, sql: &str) -> String {
475        let segments = self.tokenize(sql);
476        let reset = "\x1b[0m";
477
478        segments
479            .iter()
480            .map(|seg| {
481                let color = self.color_for_token(seg.token);
482                if color.is_empty() {
483                    seg.text.clone()
484                } else {
485                    format!("{}{}{}", color, seg.text, reset)
486                }
487            })
488            .collect()
489    }
490
491    /// Return plain SQL (no highlighting).
492    #[must_use]
493    pub fn plain(&self, sql: &str) -> String {
494        sql.to_string()
495    }
496
497    /// Format SQL with indentation (basic pretty-print).
498    #[must_use]
499    pub fn format(&self, sql: &str) -> String {
500        let segments = self.tokenize(sql);
501        let mut result = String::new();
502        let mut indent = 0;
503        let indent_str = "  ";
504        let mut newline_before = false;
505
506        for seg in segments {
507            let upper = seg.text.to_uppercase();
508
509            // Keywords that start a new line with same indentation
510            if matches!(
511                upper.as_str(),
512                "SELECT"
513                    | "FROM"
514                    | "WHERE"
515                    | "ORDER"
516                    | "GROUP"
517                    | "HAVING"
518                    | "LIMIT"
519                    | "OFFSET"
520                    | "SET"
521                    | "VALUES"
522                    | "RETURNING"
523                    | "UNION"
524                    | "INTERSECT"
525                    | "EXCEPT"
526            ) {
527                if !result.is_empty() && !result.ends_with('\n') {
528                    result.push('\n');
529                }
530                result.push_str(&indent_str.repeat(indent));
531                newline_before = false;
532            }
533
534            // Keywords that increase indentation
535            if matches!(upper.as_str(), "(" | "CASE") {
536                indent += 1;
537            }
538
539            // Keywords that decrease indentation
540            if matches!(upper.as_str(), ")" | "END") {
541                indent = indent.saturating_sub(1);
542            }
543
544            // Add keyword that needs newline before
545            if matches!(upper.as_str(), "AND" | "OR")
546                && !newline_before
547                && !result.ends_with('\n')
548                && !result.ends_with(' ')
549            {
550                result.push('\n');
551                result.push_str(&indent_str.repeat(indent + 1));
552            }
553
554            // Handle JOIN keywords
555            if matches!(
556                upper.as_str(),
557                "JOIN" | "LEFT" | "RIGHT" | "INNER" | "OUTER" | "CROSS" | "FULL"
558            ) {
559                if !result.ends_with('\n') && !result.ends_with(' ') && upper != "JOIN" {
560                    // Keep LEFT/RIGHT/etc with JOIN on same line
561                } else if upper == "JOIN" && !result.ends_with(' ') {
562                    result.push(' ');
563                }
564            }
565
566            // Append the text
567            if seg.token == SqlToken::Whitespace {
568                // Normalize whitespace
569                if !result.ends_with(' ') && !result.ends_with('\n') {
570                    result.push(' ');
571                }
572            } else {
573                result.push_str(&seg.text);
574            }
575
576            newline_before = seg.token == SqlToken::Whitespace && seg.text.contains('\n');
577        }
578
579        result.trim().to_string()
580    }
581}
582
583impl Default for SqlHighlighter {
584    fn default() -> Self {
585        Self::new()
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_highlighter_new() {
        let h = SqlHighlighter::new();
        assert!(h.highlight("SELECT 1").contains("SELECT"));
    }

    #[test]
    fn test_tokenize_select() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users");

        let tokens: Vec<SqlToken> = segments.iter().map(|s| s.token).collect();
        assert!(tokens.contains(&SqlToken::Keyword));
        assert!(tokens.contains(&SqlToken::Identifier));
    }

    #[test]
    fn test_tokenize_string() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 'hello'");

        let has_string = segments
            .iter()
            .any(|s| s.token == SqlToken::String && s.text == "'hello'");
        assert!(has_string);
    }

    #[test]
    fn test_tokenize_number() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 42, 3.14");

        let numbers: Vec<&str> = segments
            .iter()
            .filter(|s| s.token == SqlToken::Number)
            .map(|s| s.text.as_str())
            .collect();
        assert!(numbers.contains(&"42"));
        assert!(numbers.contains(&"3.14"));
    }

    #[test]
    fn test_tokenize_comment_single() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 1 -- comment");

        assert!(segments.iter().any(|s| s.token == SqlToken::Comment));
    }

    #[test]
    fn test_tokenize_comment_multi() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT /* comment */ 1");

        assert!(segments.iter().any(|s| s.token == SqlToken::Comment));
    }

    #[test]
    fn test_tokenize_parameter_positional() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE id = $1");

        let has_param = segments
            .iter()
            .any(|s| s.token == SqlToken::Parameter && s.text == "$1");
        assert!(has_param);
    }

    #[test]
    fn test_tokenize_parameter_question() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE id = ?");

        let has_param = segments
            .iter()
            .any(|s| s.token == SqlToken::Parameter && s.text == "?");
        assert!(has_param);
    }

    #[test]
    fn test_tokenize_parameter_named() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE id = :user_id");

        let has_param = segments
            .iter()
            .any(|s| s.token == SqlToken::Parameter && s.text == ":user_id");
        assert!(has_param);
    }

    #[test]
    fn test_tokenize_operators() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT * FROM users WHERE age >= 18 AND active = true");

        let has_ge = segments
            .iter()
            .any(|s| s.token == SqlToken::Operator && s.text == ">=");
        let has_and = segments
            .iter()
            .any(|s| s.token == SqlToken::Operator && s.text.to_uppercase() == "AND");
        assert!(has_ge);
        assert!(has_and);
    }

    #[test]
    fn test_tokenize_quoted_identifier() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT \"user-name\" FROM users");

        let has_quoted = segments
            .iter()
            .any(|s| s.token == SqlToken::Identifier && s.text == "\"user-name\"");
        assert!(has_quoted);
    }

    #[test]
    fn test_tokenize_round_trip() {
        // Segments are contiguous slices of the input, so joining them must
        // reproduce the original text exactly.
        let h = SqlHighlighter::new();
        let sql = "SELECT a, 'x''y', \"q\" -- note\nFROM t WHERE b >= $1 /* c */ AND d = :name;";
        let rebuilt: String = h.tokenize(sql).iter().map(|s| s.text.as_str()).collect();
        assert_eq!(rebuilt, sql);
    }

    #[test]
    fn test_highlight_produces_ansi() {
        let h = SqlHighlighter::new();
        let highlighted = h.highlight("SELECT 1");

        // Should contain ANSI escape codes
        assert!(highlighted.contains('\x1b'));
        // Should contain the text
        assert!(highlighted.contains("SELECT"));
        assert!(highlighted.contains('1'));
    }

    #[test]
    fn test_plain_no_change() {
        let h = SqlHighlighter::new();
        let sql = "SELECT * FROM users";
        assert_eq!(h.plain(sql), sql);
    }

    #[test]
    fn test_format_basic() {
        let h = SqlHighlighter::new();
        let formatted = h.format("SELECT * FROM users WHERE id = 1");

        // Major clauses each start on their own line.
        assert!(formatted.starts_with("SELECT"));
        assert!(formatted.contains("\nFROM users"));
        assert!(formatted.contains("\nWHERE id = 1"));
    }

    #[test]
    fn test_is_keyword() {
        assert!(SqlHighlighter::is_keyword("SELECT"));
        assert!(SqlHighlighter::is_keyword("select"));
        assert!(SqlHighlighter::is_keyword("Select"));
        assert!(!SqlHighlighter::is_keyword("users"));
    }

    #[test]
    fn test_is_operator_keyword() {
        assert!(SqlHighlighter::is_operator_keyword("AND"));
        assert!(SqlHighlighter::is_operator_keyword("or"));
        assert!(!SqlHighlighter::is_operator_keyword("SELECT"));
    }

    #[test]
    fn test_escaped_string() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 'it''s'");

        let string_seg = segments
            .iter()
            .find(|s| s.token == SqlToken::String)
            .expect("a string literal segment");
        assert_eq!(string_seg.text, "'it''s'");
    }

    #[test]
    fn test_scientific_notation() {
        let h = SqlHighlighter::new();
        let segments = h.tokenize("SELECT 1.5e10");

        let has_num = segments
            .iter()
            .any(|s| s.token == SqlToken::Number && s.text.contains('e'));
        assert!(has_num);
    }

    #[test]
    fn test_with_theme() {
        let h = SqlHighlighter::with_theme(Theme::light());
        let highlighted = h.highlight("SELECT 1");
        assert!(highlighted.contains('\x1b'));
    }

    #[test]
    fn test_builder_theme() {
        let h = SqlHighlighter::new().theme(Theme::dark());
        let highlighted = h.highlight("SELECT 1");
        assert!(highlighted.contains('\x1b'));
    }
}