systemprompt_database/services/schema_linter/
mod.rs1mod classify;
56mod columns;
57mod location;
58
59use std::fmt;
60
61use pg_query::protobuf::node::Node;
62
63use classify::{imperative_reason, warn_create_table_missing_if_not_exists};
64use columns::{TableDef, check_index_columns, check_view_columns, collect_create_stmt};
65use location::{LineIndex, StmtLoc, stmt_start_offset};
66
/// Severity attached to a [`LintError`] diagnostic.
///
/// Rendered through `Display` as the lowercase words `error` and `warning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LintSeverity {
    /// A violation; `lint_declarative_schema` returns `Err` as soon as one
    /// diagnostic carries this severity.
    Error,
    /// An advisory finding; on its own it does not make linting fail.
    Warning,
}

impl fmt::Display for LintSeverity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Error => "error",
            Self::Warning => "warning",
        };
        // `pad` instead of `write_str`: the output for a plain `{}` is the
        // same, but width / fill / alignment flags such as `{:>8}` are now
        // honoured, which callers rendering aligned reports rely on.
        f.pad(label)
    }
}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct LintError {
84 pub line: u32,
85 pub column: u32,
86 pub severity: LintSeverity,
87 pub message: String,
88 pub source: String,
89}
90
91impl fmt::Display for LintError {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 write!(
94 f,
95 "{}:{}:{}: {}: {}",
96 self.source, self.line, self.column, self.severity, self.message
97 )
98 }
99}
100
101#[must_use]
107pub fn created_table_names(sql: &str) -> Vec<String> {
108 let Ok(parsed) = pg_query::parse(sql) else {
109 return Vec::new();
110 };
111 parsed
112 .protobuf
113 .stmts
114 .iter()
115 .filter_map(|raw| match raw.stmt.as_ref()?.node.as_ref()? {
116 Node::CreateStmt(create) => collect_create_stmt(create).map(|t| t.name().to_owned()),
117 _ => None,
118 })
119 .collect()
120}
121
122pub fn lint_declarative_schema(sql: &str, source: &str) -> Result<(), Vec<LintError>> {
125 let parsed = match pg_query::parse(sql) {
126 Ok(p) => p,
127 Err(e) => {
128 return Err(vec![LintError {
129 line: 1,
130 column: 1,
131 severity: LintSeverity::Error,
132 message: format!("SQL parse failed: {e}"),
133 source: source.to_owned(),
134 }]);
135 },
136 };
137
138 let line_index = LineIndex::new(sql);
139 let stmts = &parsed.protobuf.stmts;
140 let (tables, mut errors) = classify_pass(stmts, sql, &line_index, source);
141 errors.extend(column_ref_pass(stmts, sql, &line_index, &tables, source));
142
143 if errors.iter().any(|e| e.severity == LintSeverity::Error) {
144 return Err(errors);
145 }
146 Ok(())
147}
148
149fn classify_pass(
150 stmts: &[pg_query::protobuf::RawStmt],
151 sql: &str,
152 line_index: &LineIndex,
153 source: &str,
154) -> (Vec<TableDef>, Vec<LintError>) {
155 let mut errors: Vec<LintError> = Vec::new();
156 let mut tables: Vec<TableDef> = Vec::new();
157
158 for raw in stmts {
159 let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
160 let (line, col) = line_index.position(location);
161 let loc = StmtLoc { line, col, source };
162
163 let Some(stmt) = raw.stmt.as_ref() else {
164 continue;
165 };
166 let Some(node) = stmt.node.as_ref() else {
167 continue;
168 };
169
170 match node {
171 Node::CreateStmt(create) => {
172 if let Some(table) = collect_create_stmt(create) {
173 tables.push(table);
174 }
175 if let Some(warn) = warn_create_table_missing_if_not_exists(create, &loc) {
176 errors.push(warn);
177 }
178 },
179 Node::IndexStmt(_)
180 | Node::CreateFunctionStmt(_)
181 | Node::ViewStmt(_)
182 | Node::CreateTrigStmt(_)
183 | Node::CompositeTypeStmt(_)
184 | Node::CreateEnumStmt(_)
185 | Node::CommentStmt(_) => {},
186 Node::CreateExtensionStmt(ext) => {
187 if !ext.if_not_exists {
188 errors.push(LintError {
189 line,
190 column: col,
191 severity: LintSeverity::Warning,
192 message: "CREATE EXTENSION without IF NOT EXISTS".into(),
193 source: source.to_owned(),
194 });
195 }
196 },
197 other => {
198 if let Some(reason) = imperative_reason(other) {
199 errors.push(LintError {
200 line,
201 column: col,
202 severity: LintSeverity::Error,
203 message: format!(
204 "imperative SQL in declarative schema: {reason} — move to \
205 schema/migrations/NNN_<name>.sql"
206 ),
207 source: source.to_owned(),
208 });
209 }
210 },
211 }
212 }
213
214 (tables, errors)
215}
216
217fn column_ref_pass(
218 stmts: &[pg_query::protobuf::RawStmt],
219 sql: &str,
220 line_index: &LineIndex,
221 tables: &[TableDef],
222 source: &str,
223) -> Vec<LintError> {
224 let mut errors: Vec<LintError> = Vec::new();
225
226 for raw in stmts {
227 let Some(stmt) = raw.stmt.as_ref() else {
228 continue;
229 };
230 let Some(node) = stmt.node.as_ref() else {
231 continue;
232 };
233 let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
234 let (line, col) = line_index.position(location);
235 let loc = StmtLoc { line, col, source };
236
237 match node {
238 Node::IndexStmt(idx) => {
239 check_index_columns(idx, tables, &loc, &mut errors);
240 },
241 Node::ViewStmt(view) => {
242 check_view_columns(view, tables, &loc, &mut errors);
243 },
244 _ => {},
245 }
246 }
247
248 errors
249}