Skip to main content

systemprompt_database/services/schema_linter/
mod.rs

1//! Declarative-schema linter.
2//!
3//! Parses each schema with [`pg_query`] (the actual `PostgreSQL` parser,
4//! exposed as a protobuf AST) and walks top-level statements. Classification is
5//! by AST node variant rather than keyword tokens, so identifier-equal strings
6//! such as a column literally named `alter` do not produce false positives,
7//! and dollar-quoted PL/pgSQL bodies are skipped at the parser level.
8//!
9//! ## Allowed top-level statements
10//!
11//! - `CreateStmt` — `CREATE TABLE`
12//! - `IndexStmt` — `CREATE [UNIQUE] INDEX`
13//! - `CreateFunctionStmt`
14//! - `ViewStmt` — `CREATE [OR REPLACE] VIEW`
15//! - `CreateTrigStmt`
16//! - `CompositeTypeStmt` — `CREATE TYPE … AS (…)`
17//! - `CreateEnumStmt` — `CREATE TYPE … AS ENUM`
18//! - `CreateExtensionStmt`
19//! - `CommentStmt` — `COMMENT ON …`
20//! - `DropStmt` — only `DROP VIEW`/`MATERIALIZED VIEW`/`INDEX`/`TRIGGER … IF
21//!   EXISTS`. These objects are stateless derived artifacts: dropping one loses
22//!   no data and the sibling `CREATE …` statement rebuilds it, so the pair
23//!   stays idempotent. `DROP TABLE`/`DROP COLUMN` remain rejected.
24//!
25//! ## Rejected top-level statements
26//!
27//! - `AlterTableStmt`
28//! - `DropStmt` — except the stateless-object carve-out above
29//! - `InsertStmt` / `UpdateStmt` / `DeleteStmt` / `TruncateStmt`
30//! - `GrantStmt` / `RevokeStmt`
31//! - `RenameStmt` — any object rename
32//! - `DoStmt` — anonymous `DO $$ … $$` blocks
33//! - Any bare `SELECT`/`COPY`/imperative statement
34//!
35//! ## Semantic checks
36//!
37//! For statements that reference columns of a table defined elsewhere in the
38//! same input (`CREATE INDEX`, `CREATE VIEW`), the linter resolves the
39//! `(table, column)` pair against an in-input schema graph built from sibling
40//! `CREATE TABLE` nodes. References to tables that are not declared in the
41//! same input (e.g. cross-extension `REFERENCES`) are intentionally not
42//! resolved — the parser sees those as forward references the database itself
43//! validates at apply-time.
44//!
45//! Column resolution does not descend into:
46//!
47//! - PL/pgSQL function bodies (resolved by Postgres at function call time)
48//! - `CHECK` constraint expressions (resolved by Postgres at table creation)
49//! - Trigger function bodies
50//!
51//! These are deferred so the linter behaves identically to the database for
52//! anything it cannot statically prove, avoiding false positives on
53//! late-bound names.
54
55mod classify;
56mod columns;
57mod location;
58
59use std::fmt;
60
61use pg_query::protobuf::node::Node;
62
63use classify::{imperative_reason, warn_create_table_missing_if_not_exists};
64use columns::{TableDef, check_index_columns, check_view_columns, collect_create_stmt};
65use location::{LineIndex, StmtLoc, stmt_start_offset};
66
/// How serious a lint finding is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LintSeverity {
    /// Rendered as `error`.
    Error,
    /// Rendered as `warning`.
    Warning,
}

impl fmt::Display for LintSeverity {
    /// Writes the lowercase label of the severity (`error` / `warning`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Error => "error",
            Self::Warning => "warning",
        };
        f.write_str(label)
    }
}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct LintError {
84    pub line: u32,
85    pub column: u32,
86    pub severity: LintSeverity,
87    pub message: String,
88    pub source: String,
89}
90
91impl fmt::Display for LintError {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(
94            f,
95            "{}:{}:{}: {}: {}",
96            self.source, self.line, self.column, self.severity, self.message
97        )
98    }
99}
100
101/// The single source of truth for which tables an extension *owns*.
102///
103/// Ownership is derived from the declarative schema, never hand-authored. A
104/// parse failure yields an empty list — the linter reports the parse error
105/// separately.
106#[must_use]
107pub fn created_table_names(sql: &str) -> Vec<String> {
108    let Ok(parsed) = pg_query::parse(sql) else {
109        return Vec::new();
110    };
111    parsed
112        .protobuf
113        .stmts
114        .iter()
115        .filter_map(|raw| match raw.stmt.as_ref()?.node.as_ref()? {
116            Node::CreateStmt(create) => collect_create_stmt(create).map(|t| t.name().to_owned()),
117            _ => None,
118        })
119        .collect()
120}
121
122/// `source` is not read; it is the label stamped into error messages (typically
123/// the schema table name or the file path).
124pub fn lint_declarative_schema(sql: &str, source: &str) -> Result<(), Vec<LintError>> {
125    let parsed = match pg_query::parse(sql) {
126        Ok(p) => p,
127        Err(e) => {
128            return Err(vec![LintError {
129                line: 1,
130                column: 1,
131                severity: LintSeverity::Error,
132                message: format!("SQL parse failed: {e}"),
133                source: source.to_owned(),
134            }]);
135        },
136    };
137
138    let line_index = LineIndex::new(sql);
139    let stmts = &parsed.protobuf.stmts;
140    let (tables, mut errors) = classify_pass(stmts, sql, &line_index, source);
141    errors.extend(column_ref_pass(stmts, sql, &line_index, &tables, source));
142
143    if errors.iter().any(|e| e.severity == LintSeverity::Error) {
144        return Err(errors);
145    }
146    Ok(())
147}
148
149fn classify_pass(
150    stmts: &[pg_query::protobuf::RawStmt],
151    sql: &str,
152    line_index: &LineIndex,
153    source: &str,
154) -> (Vec<TableDef>, Vec<LintError>) {
155    let mut errors: Vec<LintError> = Vec::new();
156    let mut tables: Vec<TableDef> = Vec::new();
157
158    for raw in stmts {
159        let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
160        let (line, col) = line_index.position(location);
161        let loc = StmtLoc { line, col, source };
162
163        let Some(stmt) = raw.stmt.as_ref() else {
164            continue;
165        };
166        let Some(node) = stmt.node.as_ref() else {
167            continue;
168        };
169
170        match node {
171            Node::CreateStmt(create) => {
172                if let Some(table) = collect_create_stmt(create) {
173                    tables.push(table);
174                }
175                if let Some(warn) = warn_create_table_missing_if_not_exists(create, &loc) {
176                    errors.push(warn);
177                }
178            },
179            Node::IndexStmt(_)
180            | Node::CreateFunctionStmt(_)
181            | Node::ViewStmt(_)
182            | Node::CreateTrigStmt(_)
183            | Node::CompositeTypeStmt(_)
184            | Node::CreateEnumStmt(_)
185            | Node::CommentStmt(_) => {},
186            Node::CreateExtensionStmt(ext) => {
187                if !ext.if_not_exists {
188                    errors.push(LintError {
189                        line,
190                        column: col,
191                        severity: LintSeverity::Warning,
192                        message: "CREATE EXTENSION without IF NOT EXISTS".into(),
193                        source: source.to_owned(),
194                    });
195                }
196            },
197            other => {
198                if let Some(reason) = imperative_reason(other) {
199                    errors.push(LintError {
200                        line,
201                        column: col,
202                        severity: LintSeverity::Error,
203                        message: format!(
204                            "imperative SQL in declarative schema: {reason} — move to \
205                             schema/migrations/NNN_<name>.sql"
206                        ),
207                        source: source.to_owned(),
208                    });
209                }
210            },
211        }
212    }
213
214    (tables, errors)
215}
216
217fn column_ref_pass(
218    stmts: &[pg_query::protobuf::RawStmt],
219    sql: &str,
220    line_index: &LineIndex,
221    tables: &[TableDef],
222    source: &str,
223) -> Vec<LintError> {
224    let mut errors: Vec<LintError> = Vec::new();
225
226    for raw in stmts {
227        let Some(stmt) = raw.stmt.as_ref() else {
228            continue;
229        };
230        let Some(node) = stmt.node.as_ref() else {
231            continue;
232        };
233        let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
234        let (line, col) = line_index.position(location);
235        let loc = StmtLoc { line, col, source };
236
237        match node {
238            Node::IndexStmt(idx) => {
239                check_index_columns(idx, tables, &loc, &mut errors);
240            },
241            Node::ViewStmt(view) => {
242                check_view_columns(view, tables, &loc, &mut errors);
243            },
244            _ => {},
245        }
246    }
247
248    errors
249}