Skip to main content

systemprompt_database/services/schema_linter/
mod.rs

//! Declarative-schema linter.
//!
//! Parses each schema with [`pg_query`] (the actual `PostgreSQL` parser,
//! exposed as a protobuf AST) and walks the top-level statements.
//! Classification is by AST node variant rather than keyword tokens, so
//! identifier-equal strings such as a column literally named `alter` do not
//! produce false positives, and dollar-quoted PL/pgSQL bodies reach the AST as
//! string literals, so their contents are never classified as statements.
//!
//! ## Allowed top-level statements
//!
//! - `CreateStmt` — `CREATE TABLE`
//! - `IndexStmt` — `CREATE [UNIQUE] INDEX`
//! - `CreateFunctionStmt` — `CREATE [OR REPLACE] FUNCTION`
//! - `ViewStmt` — `CREATE [OR REPLACE] VIEW`
//! - `CreateTrigStmt` — `CREATE TRIGGER`
//! - `CompositeTypeStmt` — `CREATE TYPE … AS (…)`
//! - `CreateEnumStmt` — `CREATE TYPE … AS ENUM`
//! - `CreateExtensionStmt` — `CREATE EXTENSION`
//! - `CommentStmt` — `COMMENT ON …`
//! - `DropStmt` — only `DROP VIEW`/`MATERIALIZED VIEW`/`INDEX`/`TRIGGER … IF
//!   EXISTS`. These objects are stateless derived artifacts: dropping one loses
//!   no data and the sibling `CREATE …` statement rebuilds it, so the pair
//!   stays idempotent. `DROP TABLE` and `DROP COLUMN` remain rejected.
//!
//! ## Rejected top-level statements
//!
//! - `AlterTableStmt`
//! - `DropStmt` — except the stateless-object carve-out above
//! - `InsertStmt` / `UpdateStmt` / `DeleteStmt` / `TruncateStmt`
//! - `GrantStmt` — both `GRANT` and `REVOKE` (the parser uses one node for both)
//! - `RenameStmt` — any object rename
//! - `DoStmt` — anonymous `DO $$ … $$` blocks
//! - Any other imperative statement, such as a bare `SELECT` or `COPY`
//!
//! ## Semantic checks
//!
//! For statements that reference columns of a table defined elsewhere in the
//! same input (`CREATE INDEX`, `CREATE VIEW`), the linter resolves each
//! `(table, column)` pair against a schema graph built from the sibling
//! `CREATE TABLE` statements in that input. References to tables that are not
//! declared in the same input (e.g. cross-extension `REFERENCES`) are
//! intentionally left unresolved: the linter treats them as forward
//! references, which the database itself validates when the schema is applied.
//!
//! Column resolution does not descend into:
//!
//! - PL/pgSQL function bodies (resolved by Postgres at function call time)
//! - `CHECK` constraint expressions (resolved by Postgres at table creation)
//! - Trigger function bodies
//!
//! These are left to the database. Anything the linter cannot statically
//! prove goes unreported, which avoids false positives on late-bound names.
54
55mod classify;
56mod columns;
57mod location;
58
59use std::fmt;
60
61use pg_query::protobuf::node::Node;
62
63use classify::{imperative_reason, warn_create_table_missing_if_not_exists};
64use columns::{TableDef, check_index_columns, check_view_columns, collect_create_stmt};
65use location::{LineIndex, StmtLoc, stmt_start_offset};
66
/// How serious a single lint finding is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LintSeverity {
    /// A violation that makes the lint fail.
    Error,
    /// An advisory finding; it never makes the lint fail on its own.
    Warning,
}

impl fmt::Display for LintSeverity {
    /// Writes the lowercase label (`error` / `warning`) used in diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Error => "error",
            Self::Warning => "warning",
        };
        f.write_str(label)
    }
}

/// One finding produced by the linter, anchored to the start of the
/// offending statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LintError {
    /// Line of the offending statement.
    pub line: u32,
    /// Column of the offending statement.
    pub column: u32,
    /// Whether this finding fails the lint or is advisory.
    pub severity: LintSeverity,
    /// Human-readable description of the problem.
    pub message: String,
    /// Caller-supplied label identifying the linted input (file path, table
    /// name, ...).
    pub source: String,
}

impl fmt::Display for LintError {
    /// Renders as `source:line:column: severity: message`, the conventional
    /// compiler-diagnostic layout.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            line,
            column,
            severity,
            message,
            source,
        } = self;
        write!(f, "{source}:{line}:{column}: {severity}: {message}")
    }
}
100
101/// Names of every table created by a `CREATE TABLE` in `sql`.
102///
103/// This is the single source of truth for which tables an extension *owns*:
104/// ownership is derived from its declarative schema, never hand-authored. A
105/// parse failure yields an empty list — the linter reports the parse error
106/// separately.
107#[must_use]
108pub fn created_table_names(sql: &str) -> Vec<String> {
109    let Ok(parsed) = pg_query::parse(sql) else {
110        return Vec::new();
111    };
112    parsed
113        .protobuf
114        .stmts
115        .iter()
116        .filter_map(|raw| match raw.stmt.as_ref()?.node.as_ref()? {
117            Node::CreateStmt(create) => collect_create_stmt(create).map(|t| t.name().to_string()),
118            _ => None,
119        })
120        .collect()
121}
122
123/// Lint a single declarative schema file. Returns the list of violations,
124/// or `Ok(())` if the script is purely declarative.
125///
126/// `source` is the label included in error messages (typically the schema
127/// table name or the file path).
128pub fn lint_declarative_schema(sql: &str, source: &str) -> Result<(), Vec<LintError>> {
129    let parsed = match pg_query::parse(sql) {
130        Ok(p) => p,
131        Err(e) => {
132            return Err(vec![LintError {
133                line: 1,
134                column: 1,
135                severity: LintSeverity::Error,
136                message: format!("SQL parse failed: {e}"),
137                source: source.to_string(),
138            }]);
139        },
140    };
141
142    let line_index = LineIndex::new(sql);
143    let stmts = &parsed.protobuf.stmts;
144    let (tables, mut errors) = classify_pass(stmts, sql, &line_index, source);
145    errors.extend(column_ref_pass(stmts, sql, &line_index, &tables, source));
146
147    if errors.iter().any(|e| e.severity == LintSeverity::Error) {
148        return Err(errors);
149    }
150    Ok(())
151}
152
153/// First pass: classify every top-level statement as allowed or rejected and
154/// collect the in-input `CREATE TABLE` graph for the column-resolution pass.
155fn classify_pass(
156    stmts: &[pg_query::protobuf::RawStmt],
157    sql: &str,
158    line_index: &LineIndex,
159    source: &str,
160) -> (Vec<TableDef>, Vec<LintError>) {
161    let mut errors: Vec<LintError> = Vec::new();
162    let mut tables: Vec<TableDef> = Vec::new();
163
164    for raw in stmts {
165        let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
166        let (line, col) = line_index.position(location);
167        let loc = StmtLoc { line, col, source };
168
169        let Some(stmt) = raw.stmt.as_ref() else {
170            continue;
171        };
172        let Some(node) = stmt.node.as_ref() else {
173            continue;
174        };
175
176        match node {
177            Node::CreateStmt(create) => {
178                if let Some(table) = collect_create_stmt(create) {
179                    tables.push(table);
180                }
181                if let Some(warn) = warn_create_table_missing_if_not_exists(create, &loc) {
182                    errors.push(warn);
183                }
184            },
185            Node::IndexStmt(_)
186            | Node::CreateFunctionStmt(_)
187            | Node::ViewStmt(_)
188            | Node::CreateTrigStmt(_)
189            | Node::CompositeTypeStmt(_)
190            | Node::CreateEnumStmt(_)
191            | Node::CommentStmt(_) => {},
192            Node::CreateExtensionStmt(ext) => {
193                if !ext.if_not_exists {
194                    errors.push(LintError {
195                        line,
196                        column: col,
197                        severity: LintSeverity::Warning,
198                        message: "CREATE EXTENSION without IF NOT EXISTS".into(),
199                        source: source.to_string(),
200                    });
201                }
202            },
203            other => {
204                if let Some(reason) = imperative_reason(other) {
205                    errors.push(LintError {
206                        line,
207                        column: col,
208                        severity: LintSeverity::Error,
209                        message: format!(
210                            "imperative SQL in declarative schema: {reason} — move to \
211                             schema/migrations/NNN_<name>.sql"
212                        ),
213                        source: source.to_string(),
214                    });
215                }
216            },
217        }
218    }
219
220    (tables, errors)
221}
222
223/// Second pass: resolve `(table, column)` references in `CREATE INDEX` and
224/// `CREATE VIEW` statements against the table graph from [`classify_pass`].
225fn column_ref_pass(
226    stmts: &[pg_query::protobuf::RawStmt],
227    sql: &str,
228    line_index: &LineIndex,
229    tables: &[TableDef],
230    source: &str,
231) -> Vec<LintError> {
232    let mut errors: Vec<LintError> = Vec::new();
233
234    for raw in stmts {
235        let Some(stmt) = raw.stmt.as_ref() else {
236            continue;
237        };
238        let Some(node) = stmt.node.as_ref() else {
239            continue;
240        };
241        let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
242        let (line, col) = line_index.position(location);
243        let loc = StmtLoc { line, col, source };
244
245        match node {
246            Node::IndexStmt(idx) => {
247                check_index_columns(idx, tables, &loc, &mut errors);
248            },
249            Node::ViewStmt(view) => {
250                check_view_columns(view, tables, &loc, &mut errors);
251            },
252            _ => {},
253        }
254    }
255
256    errors
257}