Skip to main content

schema_risk/
parser.rs

1//! Thin wrapper around sqlparser-rs.
2//!
3//! Normalises the raw AST into our own `ParsedStatement` enum so the rest of
4//! the tool never needs to import sqlparser types directly.
5
6use crate::error::Result;
7use sqlparser::ast::{AlterTableOperation, ColumnOption, ObjectType, Statement, TableConstraint};
8use sqlparser::dialect::PostgreSqlDialect;
9use sqlparser::parser::Parser;
10
11// ─────────────────────────────────────────────
12// Public normalised representation
13// ─────────────────────────────────────────────
14
/// A single column as declared in `CREATE TABLE` or `ALTER TABLE … ADD COLUMN`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name, rendered from the parsed identifier.
    pub name: String,
    /// Declared data type, rendered back to SQL text.
    pub data_type: String,
    /// `false` when the column is known to be non-nullable (e.g. declared `NOT NULL`).
    pub nullable: bool,
    /// `true` when the column declares a `DEFAULT` expression.
    pub has_default: bool,
    /// `true` when the parser identified the column as a primary-key column.
    pub is_primary_key: bool,
}
23
/// A foreign-key relationship from `CREATE TABLE` or `ALTER TABLE … ADD CONSTRAINT`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKeyInfo {
    /// Referencing columns in the table that owns the constraint.
    pub columns: Vec<String>,
    /// Referenced (parent) table.
    pub ref_table: String,
    /// Referenced columns; presumably empty when no column list was given — TODO confirm.
    pub ref_columns: Vec<String>,
    /// `true` only for `ON DELETE CASCADE`.
    pub on_delete_cascade: bool,
    /// `true` only for `ON UPDATE CASCADE`.
    pub on_update_cascade: bool,
    /// Name from `CONSTRAINT <name>`; `None` when unnamed or not captured.
    pub constraint_name: Option<String>,
}
33
/// Normalised view of every SQL statement we care about.
///
/// Built by `parse` (via `lower_to_parsed`) in this module. Statement kinds
/// without a dedicated variant end up in [`ParsedStatement::Other`].
#[derive(Debug, Clone)]
pub enum ParsedStatement {
    /// `CREATE TABLE` with its columns and foreign keys (both column-level
    /// `REFERENCES` and table-level `FOREIGN KEY` constraints).
    CreateTable {
        table: String,
        columns: Vec<ColumnInfo>,
        foreign_keys: Vec<ForeignKeyInfo>,
        /// Whether a primary key was detected in the statement.
        has_primary_key: bool,
    },
    /// `DROP TABLE` of one or more tables.
    DropTable {
        tables: Vec<String>,
        if_exists: bool,
        /// `true` when `CASCADE` was given.
        cascade: bool,
    },
    /// `ALTER TABLE … ADD COLUMN`.
    AlterTableAddColumn {
        table: String,
        column: ColumnInfo,
    },
    /// `ALTER TABLE … DROP COLUMN`.
    AlterTableDropColumn {
        table: String,
        column: String,
        if_exists: bool,
    },
    /// `ALTER TABLE … ALTER COLUMN … TYPE`; `new_type` is the target type rendered as SQL.
    AlterTableAlterColumnType {
        table: String,
        column: String,
        new_type: String,
    },
    /// `ALTER TABLE … ALTER COLUMN … SET NOT NULL`.
    AlterTableSetNotNull {
        table: String,
        column: String,
    },
    /// `ALTER TABLE … ADD [CONSTRAINT …] FOREIGN KEY …`.
    AlterTableAddForeignKey {
        table: String,
        fk: ForeignKeyInfo,
    },
    /// `ALTER TABLE … DROP CONSTRAINT …`.
    AlterTableDropConstraint {
        table: String,
        constraint: String,
        cascade: bool,
    },
    /// `ALTER TABLE … RENAME COLUMN old TO new`.
    AlterTableRenameColumn {
        table: String,
        old: String,
        new: String,
    },
    /// `ALTER TABLE old RENAME TO new`.
    AlterTableRenameTable {
        old: String,
        new: String,
    },
    /// `CREATE [UNIQUE] INDEX [CONCURRENTLY] …`.
    CreateIndex {
        /// `None` for an unnamed index.
        index_name: Option<String>,
        table: String,
        /// Indexed expressions rendered as SQL text, in declaration order.
        columns: Vec<String>,
        unique: bool,
        concurrently: bool,
    },
    /// `DROP INDEX …` of one or more indexes.
    DropIndex {
        names: Vec<String>,
        /// Best-effort: derived from the rendered index names, not a dedicated parser flag.
        concurrently: bool,
        if_exists: bool,
    },
    /// `ALTER TABLE … ADD [CONSTRAINT …] PRIMARY KEY (…)`.
    AlterTableAddPrimaryKey {
        table: String,
        columns: Vec<String>,
    },
    /// Dropping a table's primary key.
    /// NOTE(review): `lower_to_parsed` in this module does not currently emit this variant.
    AlterTableDropPrimaryKey {
        table: String,
    },
    /// `ALTER TABLE … ALTER COLUMN … SET DEFAULT` / `DROP DEFAULT`.
    AlterTableAlterColumnDefault {
        table: String,
        column: String,
        /// `true` for `DROP DEFAULT`, `false` for `SET DEFAULT`.
        drop_default: bool,
    },
    /// REINDEX INDEX/TABLE/DATABASE — can take hours with ACCESS EXCLUSIVE (without CONCURRENTLY).
    /// NOTE(review): `lower_to_parsed` does not emit this variant; REINDEX statements are
    /// currently reported as `Other` and flagged via `UNSAFE_KEYWORDS`.
    Reindex {
        /// One of "INDEX", "TABLE", "SCHEMA", "DATABASE".
        target_type: String,
        target_name: String,
        concurrently: bool,
    },
    /// CLUSTER table — full table rewrite with ACCESS EXCLUSIVE.
    /// NOTE(review): `lower_to_parsed` does not emit this variant; CLUSTER statements are
    /// currently reported as `Other` and flagged via `UNSAFE_KEYWORDS`.
    Cluster {
        table: Option<String>,
        index: Option<String>,
    },
    /// TRUNCATE TABLE — instant but irreversible.
    Truncate {
        tables: Vec<String>,
        /// `true` only for an explicit `CASCADE`.
        cascade: bool,
    },
    /// Catch-all for statements we don't inspect in detail.
    ///
    /// `raw` is a short preview: at most 80 characters of SQL, or
    /// `ALTER TABLE <name>` for an ALTER TABLE whose operations are not modelled.
    /// `parse` may append an ` [Unmodelled DDL …]` note when a dangerous keyword is found.
    Other {
        raw: String,
    },
}
129
130// ─────────────────────────────────────────────
131// Belt-and-suspenders unsafe keyword list (B-01 fix)
132// ─────────────────────────────────────────────
133
/// SQL keyword phrases that signal dangerous operations our [`ParsedStatement`]
/// enum may not model.
///
/// `parse` checks every statement that comes out as `ParsedStatement::Other`
/// (including segments sqlparser cannot parse at all) with
/// [`check_unsafe_keywords`]; a match appends an "Unmodelled DDL" note so the
/// engine can score it. Entries are tried in order, so an earlier entry wins
/// when several match.
pub const UNSAFE_KEYWORDS: &[&str] = &[
    "DROP TABLE",
    "DROP DATABASE",
    "DROP SCHEMA",
    "TRUNCATE",
    "ATTACH PARTITION",
    "DETACH PARTITION",
    "CREATE POLICY",
    "ENABLE ROW LEVEL SECURITY",
    "ALTER TABLE",
    "REINDEX",      // P0: Major operation - can take hours with ACCESS EXCLUSIVE
    "CLUSTER",      // Full table rewrite with ACCESS EXCLUSIVE
    "VACUUM FULL",  // Full table rewrite with ACCESS EXCLUSIVE
    "SET LOGGED",   // Table rewrite (UNLOGGED -> LOGGED)
    "SET UNLOGGED", // Table rewrite (LOGGED -> UNLOGGED)
];

/// Returns `Some(note)` if `raw` contains any phrase from [`UNSAFE_KEYWORDS`].
///
/// Matching is case-insensitive and word-based. `raw` is split into words made
/// of letters, digits and `_`, and a phrase matches when its words appear
/// consecutively. So `drop\n  table` matches `DROP TABLE`, while identifiers
/// such as `clusters` or `truncate_log` do not match `CLUSTER` / `TRUNCATE`.
///
/// The note names the first matching entry of [`UNSAFE_KEYWORDS`].
pub fn check_unsafe_keywords(raw: &str) -> Option<String> {
    let words: Vec<String> = raw
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|word| !word.is_empty())
        .map(str::to_uppercase)
        .collect();

    UNSAFE_KEYWORDS
        .iter()
        .find(|kw| {
            let phrase: Vec<&str> = kw.split_whitespace().collect();
            // Guard: `windows(0)` would panic on an empty phrase.
            !phrase.is_empty()
                && words.windows(phrase.len()).any(|window| {
                    window
                        .iter()
                        .zip(&phrase)
                        .all(|(word, expected)| word.as_str() == *expected)
                })
        })
        .map(|kw| {
            format!(
                "Unmodelled DDL containing '{}' — manual review required",
                kw
            )
        })
}
167
168// ─────────────────────────────────────────────
169// Parser
170// ─────────────────────────────────────────────
171
172/// Parse a full SQL string into a list of `ParsedStatement`s.
173///
174/// Fault-tolerant: unparseable statements (e.g. PL/pgSQL functions, custom
175/// extensions) are returned as `ParsedStatement::Other` rather than causing
176/// the whole file to fail.
177///
178/// **B-01 fix**: After sqlparser processing, a belt-and-suspenders sweep of the
179/// raw text checks for dangerous keywords that sqlparser might not have modelled
180/// (e.g. `ATTACH PARTITION`, `CREATE POLICY`).
181pub fn parse(sql: &str) -> Result<Vec<ParsedStatement>> {
182    let segments = split_into_segments(sql);
183    let dialect = PostgreSqlDialect {};
184    let mut results = Vec::new();
185
186    for seg in segments {
187        let trimmed = seg.trim();
188        if trimmed.is_empty() {
189            continue;
190        }
191        // Ensure the segment ends with a semicolon for sqlparser
192        let to_parse = if trimmed.ends_with(';') {
193            trimmed.to_string()
194        } else {
195            format!("{};", trimmed)
196        };
197
198        match Parser::parse_sql(&dialect, &to_parse) {
199            Ok(stmts) => {
200                for stmt in stmts {
201                    let parsed = lower_to_parsed(stmt);
202                    // B-01: if the statement came through as Other, check for unsafe keywords
203                    if let ParsedStatement::Other { ref raw } = parsed {
204                        if let Some(note) = check_unsafe_keywords(raw) {
205                            results.push(ParsedStatement::Other {
206                                raw: format!("{} [{}]", raw, note),
207                            });
208                            continue;
209                        }
210                    }
211                    results.push(parsed);
212                }
213            }
214            Err(_) => {
215                // B-01: belt-and-suspenders — check the raw segment for unsafe keywords
216                let raw_note = check_unsafe_keywords(trimmed)
217                    .map(|note| {
218                        format!(
219                            "{} [{}]",
220                            trimmed.chars().take(80).collect::<String>(),
221                            note
222                        )
223                    })
224                    .unwrap_or_else(|| trimmed.chars().take(80).collect());
225                results.push(ParsedStatement::Other { raw: raw_note });
226            }
227        }
228    }
229
230    Ok(results)
231}
232
/// Split a SQL file into statement segments.
///
/// A `;` ends a statement only in plain code. It never does so inside:
/// - a single-quoted string (`'…'`, with `''` escapes, plus backslash escapes for `E'…'`);
/// - a quoted identifier (`"…"`);
/// - a `--` line comment or a nestable `/* … */` block comment;
/// - a dollar-quoted body (`$$…$$` or `$tag$…$tag$`).
///
/// Each segment keeps its original text (comments included), is trimmed, and
/// includes its terminating `;`. Segments consisting only of `;` are dropped.
///
/// Text after the last `;` is further split on blank lines that occur in plain
/// code. A line holding only whitespace counts as blank, so `\r\n` files work too.
/// This way trailing statements written without semicolons still come out one
/// per segment.
fn split_into_segments(sql: &str) -> Vec<String> {
    /// Lexical context of the scanner.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Mode {
        Code,
        /// `'…'`; `backslash_escapes` is true for `E'…'` strings.
        SingleQuoted { backslash_escapes: bool },
        DoubleQuoted,
        LineComment,
        /// `/* … */` with its nesting depth (PostgreSQL allows nesting).
        BlockComment(usize),
        /// Inside `$tag$ … $tag$`; the delimiter is kept in `dollar_tag`.
        DollarQuoted,
    }

    /// Characters that can continue an unquoted identifier.
    fn is_identifier_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == '$'
    }

    /// True when the line currently being terminated (text after the last `\n`) is blank.
    fn current_line_is_blank(text: &str) -> bool {
        text.rfind('\n')
            .is_some_and(|pos| text[pos + 1..].trim().is_empty())
    }

    /// Returns the dollar-quote delimiter (`$$` or `$tag$`) starting at `chars[start]`, if any.
    fn dollar_quote_delimiter(chars: &[char], start: usize) -> Option<Vec<char>> {
        // A `$` continuing an identifier (e.g. `a$b`) is not a delimiter.
        if start > 0 && is_identifier_char(chars[start - 1]) {
            return None;
        }
        let mut end = start + 1;
        while let Some(&c) = chars.get(end) {
            if c == '$' {
                return Some(chars[start..=end].to_vec());
            }
            // Tags follow identifier rules: no leading digit, so `$1` stays a parameter.
            let valid = if end == start + 1 {
                c.is_alphabetic() || c == '_'
            } else {
                c.is_alphanumeric() || c == '_'
            };
            if !valid {
                return None;
            }
            end += 1;
        }
        None
    }

    let chars: Vec<char> = sql.chars().collect();
    let mut segments: Vec<String> = Vec::new();
    let mut current = String::new();
    // Byte offsets into `current` of blank lines seen in plain code.
    let mut blank_line_breaks: Vec<usize> = Vec::new();
    let mut dollar_tag: Vec<char> = Vec::new();
    let mut mode = Mode::Code;
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];
        let next = chars.get(i + 1).copied();
        // Number of chars consumed this iteration (two-char tokens and delimiters take more).
        let mut step = 1;
        let mut ends_statement = false;

        match mode {
            Mode::Code => match c {
                ';' => ends_statement = true,
                '\'' => {
                    let e_prefixed = i > 0
                        && matches!(chars[i - 1], 'E' | 'e')
                        && (i < 2 || !is_identifier_char(chars[i - 2]));
                    mode = Mode::SingleQuoted {
                        backslash_escapes: e_prefixed,
                    };
                }
                '"' => mode = Mode::DoubleQuoted,
                '-' if next == Some('-') => {
                    mode = Mode::LineComment;
                    step = 2;
                }
                '/' if next == Some('*') => {
                    mode = Mode::BlockComment(1);
                    step = 2;
                }
                '$' => {
                    if let Some(tag) = dollar_quote_delimiter(&chars, i) {
                        step = tag.len();
                        dollar_tag = tag;
                        mode = Mode::DollarQuoted;
                    }
                }
                '\n' if current_line_is_blank(&current) => blank_line_breaks.push(current.len()),
                _ => {}
            },
            Mode::SingleQuoted { backslash_escapes } => {
                if backslash_escapes && c == '\\' && next.is_some() {
                    step = 2;
                } else if c == '\'' {
                    if next == Some('\'') {
                        step = 2; // `''` is an escaped quote
                    } else {
                        mode = Mode::Code;
                    }
                }
            }
            Mode::DoubleQuoted => {
                if c == '"' {
                    if next == Some('"') {
                        step = 2; // `""` is an escaped quote
                    } else {
                        mode = Mode::Code;
                    }
                }
            }
            Mode::LineComment => {
                if c == '\n' {
                    mode = Mode::Code;
                }
            }
            Mode::BlockComment(depth) => {
                if c == '/' && next == Some('*') {
                    mode = Mode::BlockComment(depth + 1);
                    step = 2;
                } else if c == '*' && next == Some('/') {
                    mode = if depth == 1 {
                        Mode::Code
                    } else {
                        Mode::BlockComment(depth - 1)
                    };
                    step = 2;
                }
            }
            Mode::DollarQuoted => {
                if chars[i..].starts_with(&dollar_tag) {
                    step = dollar_tag.len();
                    mode = Mode::Code;
                }
            }
        }

        current.extend(&chars[i..i + step]);
        i += step;

        if ends_statement {
            let seg = current.trim();
            if !seg.is_empty() && seg != ";" {
                segments.push(seg.to_string());
            }
            current.clear();
            blank_line_breaks.clear();
        }
    }

    // Trailing content without a semicolon: split it on the recorded blank lines.
    let mut start = 0;
    for end in blank_line_breaks
        .into_iter()
        .chain(std::iter::once(current.len()))
    {
        let block = current[start..end].trim();
        if !block.is_empty() {
            segments.push(block.to_string());
        }
        start = end;
    }

    segments
}
306
307// ─────────────────────────────────────────────
308// Normalisation helpers
309// ─────────────────────────────────────────────
310
311fn lower_to_parsed(stmt: Statement) -> ParsedStatement {
312    match stmt {
313        // ── CREATE TABLE ──────────────────────────────────────────────────
314        Statement::CreateTable(ct) => {
315            let table = ct.name.to_string();
316            let mut columns = Vec::new();
317            let mut foreign_keys = Vec::new();
318            let mut has_primary_key = false;
319
320            for col_def in &ct.columns {
321                let mut nullable = true;
322                let mut has_default = false;
323                let mut is_pk = false;
324
325                for opt in &col_def.options {
326                    match &opt.option {
327                        ColumnOption::NotNull => nullable = false,
328                        ColumnOption::Default(_) => has_default = true,
329                        ColumnOption::Unique { is_primary, .. } if *is_primary => {
330                            is_pk = true;
331                            has_primary_key = true;
332                            nullable = false;
333                        }
334                        ColumnOption::ForeignKey {
335                            foreign_table,
336                            referred_columns,
337                            on_delete,
338                            on_update,
339                            ..
340                        } => {
341                            foreign_keys.push(ForeignKeyInfo {
342                                columns: vec![col_def.name.to_string()],
343                                ref_table: foreign_table.to_string(),
344                                ref_columns: referred_columns
345                                    .iter()
346                                    .map(|c| c.to_string())
347                                    .collect(),
348                                on_delete_cascade: on_delete
349                                    .as_ref()
350                                    .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
351                                    .unwrap_or(false),
352                                on_update_cascade: on_update
353                                    .as_ref()
354                                    .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
355                                    .unwrap_or(false),
356                                constraint_name: None,
357                            });
358                        }
359                        _ => {}
360                    }
361                }
362
363                columns.push(ColumnInfo {
364                    name: col_def.name.to_string(),
365                    data_type: col_def.data_type.to_string(),
366                    nullable,
367                    has_default,
368                    is_primary_key: is_pk,
369                });
370            }
371
372            // Table-level constraints
373            for constraint in &ct.constraints {
374                match constraint {
375                    TableConstraint::ForeignKey {
376                        name,
377                        columns: fk_cols,
378                        foreign_table,
379                        referred_columns,
380                        on_delete,
381                        on_update,
382                        ..
383                    } => {
384                        foreign_keys.push(ForeignKeyInfo {
385                            columns: fk_cols.iter().map(|c| c.to_string()).collect(),
386                            ref_table: foreign_table.to_string(),
387                            ref_columns: referred_columns.iter().map(|c| c.to_string()).collect(),
388                            on_delete_cascade: on_delete
389                                .as_ref()
390                                .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
391                                .unwrap_or(false),
392                            on_update_cascade: on_update
393                                .as_ref()
394                                .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
395                                .unwrap_or(false),
396                            constraint_name: name.as_ref().map(|n| n.to_string()),
397                        });
398                    }
399                    TableConstraint::PrimaryKey { .. } | TableConstraint::Unique { .. } => {
400                        has_primary_key = true;
401                    }
402                    _ => {}
403                }
404            }
405
406            ParsedStatement::CreateTable {
407                table,
408                columns,
409                foreign_keys,
410                has_primary_key,
411            }
412        }
413
414        // ── DROP TABLE ────────────────────────────────────────────────────
415        Statement::Drop {
416            object_type: ObjectType::Table,
417            names,
418            if_exists,
419            cascade,
420            ..
421        } => ParsedStatement::DropTable {
422            tables: names.iter().map(|n| n.to_string()).collect(),
423            if_exists,
424            cascade,
425        },
426
427        // ── DROP INDEX ────────────────────────────────────────────────────
428        Statement::Drop {
429            object_type: ObjectType::Index,
430            names,
431            if_exists,
432            ..
433        } => {
434            let raw = names
435                .iter()
436                .map(|n| n.to_string())
437                .collect::<Vec<_>>()
438                .join(", ");
439            // Check if CONCURRENTLY keyword appears; sqlparser puts it in name
440            let concurrently = raw.to_uppercase().contains("CONCURRENTLY");
441            ParsedStatement::DropIndex {
442                names: names.iter().map(|n| n.to_string()).collect(),
443                concurrently,
444                if_exists,
445            }
446        }
447
448        // ── CREATE INDEX ──────────────────────────────────────────────────
449        Statement::CreateIndex(ci) => {
450            let table = ci.table_name.to_string();
451            let columns = ci.columns.iter().map(|c| c.expr.to_string()).collect();
452            ParsedStatement::CreateIndex {
453                index_name: ci.name.as_ref().map(|n| n.to_string()),
454                table,
455                columns,
456                unique: ci.unique,
457                concurrently: ci.concurrently,
458            }
459        }
460
461        // ── ALTER TABLE ───────────────────────────────────────────────────
462        Statement::AlterTable {
463            name, operations, ..
464        } => {
465            let table = name.to_string();
466
467            // We handle the first meaningful operation; one ALTER TABLE
468            // per statement is the common case.
469            for op in &operations {
470                match op {
471                    // ADD COLUMN
472                    AlterTableOperation::AddColumn { column_def, .. } => {
473                        let mut nullable = true;
474                        let mut has_default = false;
475                        let mut is_pk = false;
476
477                        for opt in &column_def.options {
478                            match &opt.option {
479                                ColumnOption::NotNull => nullable = false,
480                                ColumnOption::Default(_) => has_default = true,
481                                ColumnOption::Unique { is_primary, .. } if *is_primary => {
482                                    is_pk = true;
483                                }
484                                _ => {}
485                            }
486                        }
487
488                        return ParsedStatement::AlterTableAddColumn {
489                            table,
490                            column: ColumnInfo {
491                                name: column_def.name.to_string(),
492                                data_type: column_def.data_type.to_string(),
493                                nullable,
494                                has_default,
495                                is_primary_key: is_pk,
496                            },
497                        };
498                    }
499
500                    // DROP COLUMN
501                    AlterTableOperation::DropColumn {
502                        column_name,
503                        if_exists,
504                        ..
505                    } => {
506                        return ParsedStatement::AlterTableDropColumn {
507                            table,
508                            column: column_name.to_string(),
509                            if_exists: *if_exists,
510                        };
511                    }
512
513                    // ALTER COLUMN TYPE
514                    AlterTableOperation::AlterColumn { column_name, op } => {
515                        use sqlparser::ast::AlterColumnOperation;
516                        match op {
517                            AlterColumnOperation::SetDataType { data_type, .. } => {
518                                return ParsedStatement::AlterTableAlterColumnType {
519                                    table,
520                                    column: column_name.to_string(),
521                                    new_type: data_type.to_string(),
522                                };
523                            }
524                            AlterColumnOperation::SetNotNull => {
525                                return ParsedStatement::AlterTableSetNotNull {
526                                    table,
527                                    column: column_name.to_string(),
528                                };
529                            }
530                            AlterColumnOperation::DropDefault => {
531                                return ParsedStatement::AlterTableAlterColumnDefault {
532                                    table,
533                                    column: column_name.to_string(),
534                                    drop_default: true,
535                                };
536                            }
537                            AlterColumnOperation::SetDefault { .. } => {
538                                return ParsedStatement::AlterTableAlterColumnDefault {
539                                    table,
540                                    column: column_name.to_string(),
541                                    drop_default: false,
542                                };
543                            }
544                            _ => {}
545                        }
546                    }
547
548                    // ADD CONSTRAINT (FK, PK, unique)
549                    AlterTableOperation::AddConstraint(constraint) => match constraint {
550                        TableConstraint::ForeignKey {
551                            name,
552                            columns: fk_cols,
553                            foreign_table,
554                            referred_columns,
555                            on_delete,
556                            on_update,
557                            ..
558                        } => {
559                            return ParsedStatement::AlterTableAddForeignKey {
560                                table,
561                                fk: ForeignKeyInfo {
562                                    columns: fk_cols.iter().map(|c| c.to_string()).collect(),
563                                    ref_table: foreign_table.to_string(),
564                                    ref_columns: referred_columns
565                                        .iter()
566                                        .map(|c| c.to_string())
567                                        .collect(),
568                                    on_delete_cascade: on_delete
569                                        .as_ref()
570                                        .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
571                                        .unwrap_or(false),
572                                    on_update_cascade: on_update
573                                        .as_ref()
574                                        .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
575                                        .unwrap_or(false),
576                                    constraint_name: name.as_ref().map(|n| n.to_string()),
577                                },
578                            };
579                        }
580                        TableConstraint::PrimaryKey { columns, .. } => {
581                            return ParsedStatement::AlterTableAddPrimaryKey {
582                                table,
583                                columns: columns.iter().map(|c| c.to_string()).collect(),
584                            };
585                        }
586                        _ => {}
587                    },
588
589                    // DROP CONSTRAINT
590                    AlterTableOperation::DropConstraint { name, cascade, .. } => {
591                        return ParsedStatement::AlterTableDropConstraint {
592                            table,
593                            constraint: name.to_string(),
594                            cascade: *cascade,
595                        };
596                    }
597
598                    // RENAME COLUMN
599                    AlterTableOperation::RenameColumn {
600                        old_column_name,
601                        new_column_name,
602                    } => {
603                        return ParsedStatement::AlterTableRenameColumn {
604                            table,
605                            old: old_column_name.to_string(),
606                            new: new_column_name.to_string(),
607                        };
608                    }
609
610                    // RENAME TABLE
611                    AlterTableOperation::RenameTable { table_name } => {
612                        return ParsedStatement::AlterTableRenameTable {
613                            old: table,
614                            new: table_name.to_string(),
615                        };
616                    }
617
618                    _ => {}
619                }
620            }
621
622            // Unrecognised ALTER TABLE operation
623            ParsedStatement::Other {
624                raw: format!("ALTER TABLE {}", name),
625            }
626        }
627
628        // ── TRUNCATE ─────────────────────────────────────────────────────
629        Statement::Truncate {
630            table_names,
631            cascade,
632            ..
633        } => {
634            // cascade is Option<TruncateCascadeOption> which can be Cascade or Restrict
635            let is_cascade = cascade
636                .as_ref()
637                .map(|c| matches!(c, sqlparser::ast::TruncateCascadeOption::Cascade))
638                .unwrap_or(false);
639            ParsedStatement::Truncate {
640                tables: table_names.iter().map(|t| t.name.to_string()).collect(),
641                cascade: is_cascade,
642            }
643        }
644
645        other => ParsedStatement::Other {
646            raw: other.to_string().chars().take(80).collect(),
647        },
648    }
649}