Skip to main content

systemprompt_database/services/
schema_linter.rs

1//! Declarative-schema linter.
2//!
3//! Walks a SQL script using the same lex states as
4//! [`crate::services::executor::SqlExecutor::parse_sql_statements`] (single
5//! quote, dollar quote, line/block comment) so it inspects only top-level
6//! tokens. Each top-level statement is classified by leading keywords:
7//!
8//! - **Allowed**: `CREATE TABLE [IF NOT EXISTS]`, `CREATE [UNIQUE] INDEX [IF
9//!   NOT EXISTS]`, `CREATE [OR REPLACE] FUNCTION`, `CREATE [OR REPLACE] VIEW`,
10//!   `CREATE [OR REPLACE] TRIGGER`, `CREATE TYPE`, `CREATE EXTENSION IF NOT
11//!   EXISTS`, `COMMENT ON`.
//! - **Rejected**: `ALTER` (which is how `RENAME` statements are caught),
//!   `DROP`, top-level `DO $$ … $$`, `UPDATE`, `INSERT`, `DELETE`, `TRUNCATE`,
//!   `GRANT`, `REVOKE`, `SELECT`.
14//! - **Naked `CREATE TABLE foo (…)`** without `IF NOT EXISTS` is permitted but
15//!   emitted as an informational warning (still reported as a [`LintError`]
16//!   with [`LintSeverity::Warning`]).
17//!
18//! The lexer mirrors the splitter rather than calling into it because the
19//! linter needs byte offsets — preserved as `(line, column)` — to surface
20//! useful error messages.
21
22use std::fmt;
23
/// How serious a lint finding is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LintSeverity {
    /// A violation; rendered as `error`.
    Error,
    /// An informational finding; rendered as `warning`.
    Warning,
}

impl fmt::Display for LintSeverity {
    /// Writes the lowercase label used in rendered diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Error => "error",
            Self::Warning => "warning",
        };
        f.write_str(label)
    }
}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct LintError {
41    pub line: u32,
42    pub column: u32,
43    pub severity: LintSeverity,
44    pub message: String,
45    pub source: String,
46}
47
48impl fmt::Display for LintError {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(
51            f,
52            "{}:{}:{}: {}: {}",
53            self.source, self.line, self.column, self.severity, self.message
54        )
55    }
56}
57
58/// Lint a single declarative schema file. Returns the list of violations,
59/// or `Ok(())` if the script is purely declarative.
60///
61/// `source` is the label included in error messages (typically the schema
62/// table name or the file path).
63pub fn lint_declarative_schema(sql: &str, source: &str) -> Result<(), Vec<LintError>> {
64    let statements = split_top_level_statements(sql, source)?;
65    let mut errors = Vec::new();
66    for stmt in &statements {
67        if let Some(err) = classify(stmt, source) {
68            errors.push(err);
69        }
70    }
71    if errors.iter().any(|e| e.severity == LintSeverity::Error) {
72        return Err(errors);
73    }
74    Ok(())
75}
76
/// One top-level statement produced by the splitter.
#[derive(Debug, Clone)]
struct TopStatement {
    /// Trimmed statement text, without the terminating `;`. Comments that sit
    /// between the previous `;` and this statement are still included.
    text: String,
    /// 1-based line of the first non-whitespace, non-comment byte.
    start_line: u32,
    /// 1-based column of that byte; columns are counted in bytes.
    start_column: u32,
}
83
/// Lexer state used while scanning a script for top-level `;` terminators.
enum LexState {
    /// Outside any quote or comment; `;` ends a statement here.
    Normal,
    /// Inside a `'...'` literal (`''` is an escaped quote).
    SingleQuote,
    /// Inside a dollar-quoted body; holds the full opening tag including both
    /// `$` delimiters, which is also the text that closes the body.
    DollarQuote(String),
    /// Inside a `--` comment, which runs to the next newline.
    LineComment,
    /// Inside a `/* ... */` comment; holds the current nesting depth.
    BlockComment(u32),
}
91
/// Locates the `$` that closes a dollar-quote tag opened at `bytes[start]`.
///
/// The tag body may contain only ASCII alphanumerics and `_`. Returns the
/// index of the closing `$`, or `None` when some other byte (or the end of
/// input) is reached first.
fn dollar_tag_end(bytes: &[u8], start: usize) -> Option<usize> {
    debug_assert_eq!(bytes[start], b'$');
    let tail = bytes.get(start + 1..)?;
    // The first byte that cannot be part of a tag body decides the outcome:
    // it is either the closing `$` or proof that this is not a tag.
    let stop = tail
        .iter()
        .position(|&c| !(c.is_ascii_alphanumeric() || c == b'_'))?;
    if tail[stop] == b'$' {
        Some(start + 1 + stop)
    } else {
        None
    }
}
107
108fn split_top_level_statements(
109    sql: &str,
110    source: &str,
111) -> Result<Vec<TopStatement>, Vec<LintError>> {
112    let bytes = sql.as_bytes();
113    let mut statements: Vec<TopStatement> = Vec::new();
114    let mut state = LexState::Normal;
115    let mut start = 0usize;
116    let mut i = 0usize;
117    let mut start_line: u32 = 1;
118    let mut start_col: u32 = 1;
119    let mut line: u32 = 1;
120    let mut col: u32 = 1;
121    let mut stmt_line: u32 = 1;
122    let mut stmt_col: u32 = 1;
123    let mut has_content = false;
124
125    while i < bytes.len() {
126        let b = bytes[i];
127        match &mut state {
128            LexState::Normal => match b {
129                b'\'' => {
130                    if !has_content {
131                        stmt_line = line;
132                        stmt_col = col;
133                    }
134                    has_content = true;
135                    state = LexState::SingleQuote;
136                    advance(&mut i, &mut line, &mut col, b);
137                },
138                b'-' if bytes.get(i + 1) == Some(&b'-') => {
139                    state = LexState::LineComment;
140                    advance(&mut i, &mut line, &mut col, b);
141                    advance(&mut i, &mut line, &mut col, b'-');
142                },
143                b'/' if bytes.get(i + 1) == Some(&b'*') => {
144                    state = LexState::BlockComment(1);
145                    advance(&mut i, &mut line, &mut col, b);
146                    advance(&mut i, &mut line, &mut col, b'*');
147                },
148                b'$' => {
149                    if !has_content {
150                        stmt_line = line;
151                        stmt_col = col;
152                    }
153                    has_content = true;
154                    if let Some(tag_end) = dollar_tag_end(bytes, i) {
155                        let tag = sql[i..=tag_end].to_string();
156                        let advance_by = tag_end - i + 1;
157                        for _ in 0..advance_by {
158                            advance(&mut i, &mut line, &mut col, b'$');
159                        }
160                        state = LexState::DollarQuote(tag);
161                    } else {
162                        advance(&mut i, &mut line, &mut col, b);
163                    }
164                },
165                b';' => {
166                    if has_content {
167                        let text = sql[start..i].trim().to_string();
168                        if !text.is_empty() {
169                            statements.push(TopStatement {
170                                text,
171                                start_line: stmt_line,
172                                start_column: stmt_col,
173                            });
174                        }
175                    }
176                    has_content = false;
177                    advance(&mut i, &mut line, &mut col, b);
178                    start = i;
179                    start_line = line;
180                    start_col = col;
181                },
182                _ => {
183                    if !b.is_ascii_whitespace() {
184                        if !has_content {
185                            stmt_line = line;
186                            stmt_col = col;
187                        }
188                        has_content = true;
189                    }
190                    advance(&mut i, &mut line, &mut col, b);
191                },
192            },
193            LexState::SingleQuote => {
194                if b == b'\'' {
195                    if bytes.get(i + 1) == Some(&b'\'') {
196                        advance(&mut i, &mut line, &mut col, b);
197                        advance(&mut i, &mut line, &mut col, b'\'');
198                    } else {
199                        state = LexState::Normal;
200                        advance(&mut i, &mut line, &mut col, b);
201                    }
202                } else {
203                    advance(&mut i, &mut line, &mut col, b);
204                }
205            },
206            LexState::DollarQuote(tag) => {
207                let tag_bytes = tag.as_bytes();
208                if i + tag_bytes.len() <= bytes.len() && &bytes[i..i + tag_bytes.len()] == tag_bytes
209                {
210                    for _ in 0..tag_bytes.len() {
211                        advance(&mut i, &mut line, &mut col, b'$');
212                    }
213                    state = LexState::Normal;
214                } else {
215                    advance(&mut i, &mut line, &mut col, b);
216                }
217            },
218            LexState::LineComment => {
219                if b == b'\n' {
220                    state = LexState::Normal;
221                }
222                advance(&mut i, &mut line, &mut col, b);
223            },
224            LexState::BlockComment(depth) => {
225                if b == b'/' && bytes.get(i + 1) == Some(&b'*') {
226                    *depth += 1;
227                    advance(&mut i, &mut line, &mut col, b);
228                    advance(&mut i, &mut line, &mut col, b'*');
229                } else if b == b'*' && bytes.get(i + 1) == Some(&b'/') {
230                    *depth -= 1;
231                    let zero = *depth == 0;
232                    advance(&mut i, &mut line, &mut col, b);
233                    advance(&mut i, &mut line, &mut col, b'/');
234                    if zero {
235                        state = LexState::Normal;
236                    }
237                } else {
238                    advance(&mut i, &mut line, &mut col, b);
239                }
240            },
241        }
242    }
243
244    match state {
245        LexState::Normal | LexState::LineComment => {
246            if has_content {
247                let text = sql[start..].trim().to_string();
248                if !text.is_empty() {
249                    statements.push(TopStatement {
250                        text,
251                        start_line: stmt_line,
252                        start_column: stmt_col,
253                    });
254                }
255            }
256            Ok(statements)
257        },
258        LexState::SingleQuote => Err(vec![LintError {
259            line: start_line,
260            column: start_col,
261            severity: LintSeverity::Error,
262            message: "unterminated string literal".into(),
263            source: source.to_string(),
264        }]),
265        LexState::DollarQuote(tag) => Err(vec![LintError {
266            line: start_line,
267            column: start_col,
268            severity: LintSeverity::Error,
269            message: format!("unterminated dollar-quoted string: {tag}"),
270            source: source.to_string(),
271        }]),
272        LexState::BlockComment(_) => Err(vec![LintError {
273            line: start_line,
274            column: start_col,
275            severity: LintSeverity::Error,
276            message: "unterminated block comment".into(),
277            source: source.to_string(),
278        }]),
279    }
280}
281
/// Steps the cursor past one byte `b`, keeping the `(line, col)` position in
/// sync: a newline starts the next line at column 1, any other byte moves one
/// column to the right.
fn advance(i: &mut usize, line: &mut u32, col: &mut u32, b: u8) {
    match b {
        b'\n' => {
            *line += 1;
            *col = 1;
        },
        _ => *col += 1,
    }
    *i += 1;
}
291
292fn classify(stmt: &TopStatement, source: &str) -> Option<LintError> {
293    let stripped = strip_sql_comments(&stmt.text);
294    let upper = uppercase_keywords(&stripped);
295    let tokens: Vec<&str> = upper.split_whitespace().collect();
296    if tokens.is_empty() {
297        return None;
298    }
299
300    let leading = tokens[0];
301
302    let reject = |reason: &str| LintError {
303        line: stmt.start_line,
304        column: stmt.start_column,
305        severity: LintSeverity::Error,
306        message: format!(
307            "imperative SQL in declarative schema: {reason} — move to \
308             schema/migrations/NNN_<name>.sql"
309        ),
310        source: source.to_string(),
311    };
312
313    match leading {
314        "ALTER" => return Some(reject("ALTER")),
315        "DROP" => return Some(reject("DROP")),
316        "UPDATE" => return Some(reject("UPDATE")),
317        "INSERT" => return Some(reject("INSERT")),
318        "DELETE" => return Some(reject("DELETE")),
319        "TRUNCATE" => return Some(reject("TRUNCATE")),
320        "GRANT" => return Some(reject("GRANT")),
321        "REVOKE" => return Some(reject("REVOKE")),
322        "DO" => return Some(reject("DO $$ block")),
323        _ => {},
324    }
325
326    if leading == "CREATE" {
327        return classify_create(&tokens, stmt, source);
328    }
329
330    if leading == "COMMENT" && tokens.get(1) == Some(&"ON") {
331        return None;
332    }
333
334    if leading == "SELECT" {
335        return Some(LintError {
336            line: stmt.start_line,
337            column: stmt.start_column,
338            severity: LintSeverity::Error,
339            message: "imperative SQL in declarative schema: SELECT — move to \
340                      schema/migrations/NNN_<name>.sql"
341                .into(),
342            source: source.to_string(),
343        });
344    }
345
346    None
347}
348
349fn classify_create(tokens: &[&str], stmt: &TopStatement, source: &str) -> Option<LintError> {
350    let mut idx = 1;
351
352    if tokens.get(idx) == Some(&"OR") && tokens.get(idx + 1) == Some(&"REPLACE") {
353        idx += 2;
354    }
355
356    if tokens.get(idx) == Some(&"UNIQUE") {
357        idx += 1;
358    }
359
360    let kind = match tokens.get(idx) {
361        Some(k) => *k,
362        None => return None,
363    };
364    idx += 1;
365
366    let has_if_not_exists = tokens.get(idx) == Some(&"IF")
367        && tokens.get(idx + 1) == Some(&"NOT")
368        && tokens.get(idx + 2) == Some(&"EXISTS");
369
370    match kind {
371        "TABLE" => {
372            if !has_if_not_exists {
373                return Some(LintError {
374                    line: stmt.start_line,
375                    column: stmt.start_column,
376                    severity: LintSeverity::Warning,
377                    message: "CREATE TABLE without IF NOT EXISTS — add IF NOT EXISTS for \
378                              idempotency"
379                        .into(),
380                    source: source.to_string(),
381                });
382            }
383            None
384        },
385        "EXTENSION" => {
386            if !has_if_not_exists {
387                return Some(LintError {
388                    line: stmt.start_line,
389                    column: stmt.start_column,
390                    severity: LintSeverity::Warning,
391                    message: "CREATE EXTENSION without IF NOT EXISTS".into(),
392                    source: source.to_string(),
393                });
394            }
395            None
396        },
397        _ => None,
398    }
399}
400
401fn strip_sql_comments(text: &str) -> String {
402    let bytes = text.as_bytes();
403    let mut out = String::with_capacity(text.len());
404    let mut i = 0;
405    let mut in_single = false;
406    let mut in_dollar: Option<String> = None;
407    while i < bytes.len() {
408        let b = bytes[i];
409        if let Some(tag) = &in_dollar {
410            let tag_b = tag.as_bytes();
411            if i + tag_b.len() <= bytes.len() && &bytes[i..i + tag_b.len()] == tag_b {
412                out.push_str(tag);
413                i += tag_b.len();
414                in_dollar = None;
415            } else {
416                out.push(b as char);
417                i += 1;
418            }
419            continue;
420        }
421        if in_single {
422            out.push(b as char);
423            if b == b'\'' {
424                if bytes.get(i + 1) == Some(&b'\'') {
425                    out.push('\'');
426                    i += 2;
427                    continue;
428                }
429                in_single = false;
430            }
431            i += 1;
432            continue;
433        }
434        if b == b'\'' {
435            in_single = true;
436            out.push('\'');
437            i += 1;
438            continue;
439        }
440        if b == b'$' {
441            if let Some(end) = dollar_tag_end(bytes, i) {
442                let tag = text[i..=end].to_string();
443                out.push_str(&tag);
444                i = end + 1;
445                in_dollar = Some(tag);
446                continue;
447            }
448        }
449        if b == b'-' && bytes.get(i + 1) == Some(&b'-') {
450            while i < bytes.len() && bytes[i] != b'\n' {
451                i += 1;
452            }
453            continue;
454        }
455        if b == b'/' && bytes.get(i + 1) == Some(&b'*') {
456            let mut depth = 1u32;
457            i += 2;
458            while i < bytes.len() && depth > 0 {
459                if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
460                    depth += 1;
461                    i += 2;
462                } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
463                    depth -= 1;
464                    i += 2;
465                } else {
466                    i += 1;
467                }
468            }
469            continue;
470        }
471        out.push(b as char);
472        i += 1;
473    }
474    out
475}
476
477fn uppercase_keywords(text: &str) -> String {
478    let mut out = String::with_capacity(text.len());
479    let mut in_string = false;
480    let mut in_dollar = false;
481    let bytes = text.as_bytes();
482    let mut i = 0;
483    while i < bytes.len() {
484        let b = bytes[i];
485        if !in_string && !in_dollar && b == b'$' {
486            if let Some(end) = dollar_tag_end(bytes, i) {
487                out.push_str(&text[i..=end]);
488                i = end + 1;
489                in_dollar = true;
490                continue;
491            }
492        }
493        if in_dollar && b == b'$' {
494            if let Some(end) = dollar_tag_end(bytes, i) {
495                out.push_str(&text[i..=end]);
496                i = end + 1;
497                in_dollar = false;
498                continue;
499            }
500        }
501        if in_dollar {
502            out.push(b as char);
503            i += 1;
504            continue;
505        }
506        if b == b'\'' {
507            in_string = !in_string;
508            out.push('\'');
509            i += 1;
510            continue;
511        }
512        if in_string {
513            out.push(b as char);
514            i += 1;
515            continue;
516        }
517        out.push(b.to_ascii_uppercase() as char);
518        i += 1;
519    }
520    out
521}