// schema_risk/parser.rs

//! Thin wrapper around sqlparser-rs.
//!
//! Normalises the raw AST into our own `ParsedStatement` enum so the rest of
//! the tool never needs to import sqlparser types directly.

use std::fmt::Display;

use sqlparser::ast::{AlterTableOperation, ColumnOption, ObjectType, Statement, TableConstraint};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;

use crate::error::Result;

// ─────────────────────────────────────────────
// Public normalised representation
// ─────────────────────────────────────────────

/// One column definition from a `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name as rendered by sqlparser.
    pub name: String,
    /// Declared data type, rendered back to SQL text.
    pub data_type: String,
    /// `false` when the parser determined the column cannot hold NULL
    /// (for example an explicit `NOT NULL`).
    pub nullable: bool,
    /// The column carries a `DEFAULT` clause.
    pub has_default: bool,
    /// The column was recognised as (part of) the primary key.
    pub is_primary_key: bool,
}

/// A foreign-key relationship, declared inline (`REFERENCES`) or as a
/// table-level `FOREIGN KEY` constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKeyInfo {
    /// Referencing column(s) in the table that owns the constraint.
    pub columns: Vec<String>,
    /// Referenced (parent) table.
    pub ref_table: String,
    /// Referenced column(s) in `ref_table`.
    pub ref_columns: Vec<String>,
    /// An `ON DELETE CASCADE` action was specified.
    pub on_delete_cascade: bool,
    /// An `ON UPDATE CASCADE` action was specified.
    pub on_update_cascade: bool,
    /// Explicit `CONSTRAINT <name>`, when one was given.
    pub constraint_name: Option<String>,
}

/// Normalised view of every SQL statement we care about.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParsedStatement {
    /// `CREATE TABLE` with its columns and foreign keys.
    CreateTable {
        table: String,
        columns: Vec<ColumnInfo>,
        foreign_keys: Vec<ForeignKeyInfo>,
        has_primary_key: bool,
    },
    /// `DROP TABLE [IF EXISTS] ... [CASCADE]`; one statement may name several tables.
    DropTable {
        tables: Vec<String>,
        if_exists: bool,
        cascade: bool,
    },
    /// `ALTER TABLE ... ADD COLUMN`.
    AlterTableAddColumn {
        table: String,
        column: ColumnInfo,
    },
    /// `ALTER TABLE ... DROP COLUMN [IF EXISTS]`.
    AlterTableDropColumn {
        table: String,
        column: String,
        if_exists: bool,
    },
    /// `ALTER TABLE ... ALTER COLUMN ... TYPE`; `new_type` is the target type as SQL text.
    AlterTableAlterColumnType {
        table: String,
        column: String,
        new_type: String,
    },
    /// `ALTER TABLE ... ALTER COLUMN ... SET NOT NULL`.
    AlterTableSetNotNull {
        table: String,
        column: String,
    },
    /// `ALTER TABLE ... ADD [CONSTRAINT ...] FOREIGN KEY`.
    AlterTableAddForeignKey {
        table: String,
        fk: ForeignKeyInfo,
    },
    /// `ALTER TABLE ... DROP CONSTRAINT ... [CASCADE]`.
    AlterTableDropConstraint {
        table: String,
        constraint: String,
        cascade: bool,
    },
    /// `ALTER TABLE ... RENAME COLUMN old TO new`.
    AlterTableRenameColumn {
        table: String,
        old: String,
        new: String,
    },
    /// `ALTER TABLE old RENAME TO new`.
    AlterTableRenameTable {
        old: String,
        new: String,
    },
    /// `CREATE [UNIQUE] INDEX [CONCURRENTLY]`; `columns` are the indexed
    /// expressions rendered as SQL text.
    CreateIndex {
        index_name: Option<String>,
        table: String,
        columns: Vec<String>,
        unique: bool,
        concurrently: bool,
    },
    /// `DROP INDEX`; `concurrently` is detected on a best-effort basis.
    DropIndex {
        names: Vec<String>,
        concurrently: bool,
        if_exists: bool,
    },
    /// `ALTER TABLE ... ADD [CONSTRAINT ...] PRIMARY KEY (...)`.
    AlterTableAddPrimaryKey {
        table: String,
        columns: Vec<String>,
    },
    /// Dropping a table's primary key.
    /// NOTE(review): `lower_to_parsed` does not currently emit this variant.
    AlterTableDropPrimaryKey {
        table: String,
    },
    /// `ALTER COLUMN ... SET DEFAULT` / `DROP DEFAULT`; `drop_default` is
    /// `true` for `DROP DEFAULT`.
    AlterTableAlterColumnDefault {
        table: String,
        column: String,
        drop_default: bool,
    },
    /// Catch-all for statements we don't inspect in detail. `raw` is a short,
    /// possibly truncated rendering of the statement, optionally followed by a
    /// bracketed review note.
    Other {
        raw: String,
    },
}

// ─────────────────────────────────────────────
// Belt-and-suspenders unsafe keyword list (B-01 fix)
// ─────────────────────────────────────────────

/// SQL keywords that are dangerous but may not be modelled by our ParsedStatement enum.
///
/// Segments sqlparser cannot parse, and parsed statements that lower to
/// `Other`, are checked against this list. A hit appends an "Unmodelled DDL"
/// note so the engine can score the statement. Entries are matched as
/// whole-word phrases by [`check_unsafe_keywords`]. The first matching entry,
/// in list order, wins.
pub const UNSAFE_KEYWORDS: &[&str] = &[
    "DROP TABLE",
    "DROP DATABASE",
    "DROP SCHEMA",
    "TRUNCATE",
    "ATTACH PARTITION",
    "CREATE POLICY",
    "ENABLE ROW LEVEL SECURITY",
    "ALTER TABLE",
];

/// Returns `Some(note)` if `raw` contains any [`UNSAFE_KEYWORDS`] phrase.
///
/// Matching is case-insensitive and word-based. The text is reduced to its
/// identifier-like words (letters, digits, `_`), separated by single spaces.
/// As a result, `drop\n  table` matches `DROP TABLE`, while `truncated_at`
/// does not match `TRUNCATE` and `ALTER TABLESPACE` does not match
/// `ALTER TABLE`. Keywords inside comments or string literals are still
/// matched.
pub fn check_unsafe_keywords(raw: &str) -> Option<String> {
    let upper = raw.to_uppercase();

    // Leading/trailing sentinel spaces let every phrase be matched as " KW ",
    // which enforces word boundaries on both ends.
    let mut words = String::with_capacity(upper.len() + 2);
    words.push(' ');
    for word in upper
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|w| !w.is_empty())
    {
        words.push_str(word);
        words.push(' ');
    }

    UNSAFE_KEYWORDS
        .iter()
        .find(|kw| words.contains(&format!(" {} ", kw)))
        .map(|kw| {
            format!(
                "Unmodelled DDL containing '{}' — manual review required",
                kw
            )
        })
}

// ─────────────────────────────────────────────
// Parser
// ─────────────────────────────────────────────

150/// Parse a full SQL string into a list of `ParsedStatement`s.
151///
152/// Fault-tolerant: unparseable statements (e.g. PL/pgSQL functions, custom
153/// extensions) are returned as `ParsedStatement::Other` rather than causing
154/// the whole file to fail.
155///
156/// **B-01 fix**: After sqlparser processing, a belt-and-suspenders sweep of the
157/// raw text checks for dangerous keywords that sqlparser might not have modelled
158/// (e.g. `ATTACH PARTITION`, `CREATE POLICY`).
159pub fn parse(sql: &str) -> Result<Vec<ParsedStatement>> {
160    let segments = split_into_segments(sql);
161    let dialect = PostgreSqlDialect {};
162    let mut results = Vec::new();
163
164    for seg in segments {
165        let trimmed = seg.trim();
166        if trimmed.is_empty() {
167            continue;
168        }
169        // Ensure the segment ends with a semicolon for sqlparser
170        let to_parse = if trimmed.ends_with(';') {
171            trimmed.to_string()
172        } else {
173            format!("{};", trimmed)
174        };
175
176        match Parser::parse_sql(&dialect, &to_parse) {
177            Ok(stmts) => {
178                for stmt in stmts {
179                    let parsed = lower_to_parsed(stmt);
180                    // B-01: if the statement came through as Other, check for unsafe keywords
181                    if let ParsedStatement::Other { ref raw } = parsed {
182                        if let Some(note) = check_unsafe_keywords(raw) {
183                            results.push(ParsedStatement::Other {
184                                raw: format!("{} [{}]", raw, note),
185                            });
186                            continue;
187                        }
188                    }
189                    results.push(parsed);
190                }
191            }
192            Err(_) => {
193                // B-01: belt-and-suspenders — check the raw segment for unsafe keywords
194                let raw_note = check_unsafe_keywords(trimmed)
195                    .map(|note| {
196                        format!(
197                            "{} [{}]",
198                            trimmed.chars().take(80).collect::<String>(),
199                            note
200                        )
201                    })
202                    .unwrap_or_else(|| trimmed.chars().take(80).collect());
203                results.push(ParsedStatement::Other { raw: raw_note });
204            }
205        }
206    }
207
208    Ok(results)
209}
210
/// Split a SQL file into statement segments.
///
/// Strategy:
/// 1. Treat quoted and commented text as opaque, copying it verbatim and
///    never splitting inside it. This covers single-quoted literals
///    (including `E'…'` backslash escapes), double-quoted identifiers, `--`
///    line comments, nested `/* … */` block comments, and dollar-quoted
///    bodies (`$$…$$`, `$tag$…$tag$`).
/// 2. Split on `;` everywhere else; each segment keeps its trailing `;`.
/// 3. Content after the last `;` is further split on blank (whitespace-only)
///    lines. Statements separated only by blank lines (no trailing `;`)
///    therefore still come out as separate segments.
///
/// Segments are trimmed; empty segments and a lone `;` are dropped. An
/// unterminated quote or comment runs to the end of the input.
fn split_into_segments(sql: &str) -> Vec<String> {
    let chars: Vec<char> = sql.chars().collect();
    let len = chars.len();
    let mut segments: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut i = 0;

    while i < len {
        let c = chars[i];
        let next = chars.get(i + 1).copied();

        // Exclusive end index of an opaque token starting at `i`, if one starts here.
        let opaque_end = match c {
            '-' if next == Some('-') => Some(line_comment_end(&chars, i)),
            '/' if next == Some('*') => Some(block_comment_end(&chars, i)),
            '\'' => {
                let escapes = is_escape_string_prefix(&chars, i);
                Some(quoted_end(&chars, i, '\'', escapes))
            }
            '"' => Some(quoted_end(&chars, i, '"', false)),
            '$' => dollar_tag_end(&chars, i).map(|tag_end| {
                let tag = &chars[i..=tag_end];
                find_chars(&chars, tag_end + 1, tag).map_or(len, |pos| pos + tag.len())
            }),
            _ => None,
        };

        if let Some(end) = opaque_end {
            current.extend(&chars[i..end]);
            i = end;
            continue;
        }

        if c == ';' {
            current.push(';');
            let seg = current.trim();
            if !seg.is_empty() && seg != ";" {
                segments.push(seg.to_string());
            }
            current.clear();
        } else {
            current.push(c);
        }
        i += 1;
    }

    // Trailing content without a semicolon: split on blank lines (CRLF-aware).
    // The empty sentinel line flushes the final block.
    let mut block: Vec<&str> = Vec::new();
    for line in current.lines().chain(std::iter::once("")) {
        if line.trim().is_empty() {
            if !block.is_empty() {
                segments.push(block.join("\n").trim().to_string());
                block.clear();
            }
        } else {
            block.push(line);
        }
    }

    segments
}

/// Whether `c` can continue an unquoted SQL identifier.
fn is_ident_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '$'
}

/// True when the `'` at index `quote` opens a PostgreSQL escape string
/// (`E'…'`), where backslash escapes are honoured.
fn is_escape_string_prefix(chars: &[char], quote: usize) -> bool {
    quote >= 1
        && matches!(chars[quote - 1], 'e' | 'E')
        && (quote < 2 || !is_ident_char(chars[quote - 2]))
}

/// Exclusive end of the `--` comment starting at `start` (the newline is not included).
fn line_comment_end(chars: &[char], start: usize) -> usize {
    chars[start..]
        .iter()
        .position(|&c| c == '\n')
        .map_or(chars.len(), |offset| start + offset)
}

/// Exclusive end of the `/* … */` comment starting at `start`; PostgreSQL
/// block comments nest.
fn block_comment_end(chars: &[char], start: usize) -> usize {
    let mut depth = 0usize;
    let mut i = start;
    while i < chars.len() {
        match (chars[i], chars.get(i + 1).copied()) {
            ('/', Some('*')) => {
                depth += 1;
                i += 2;
            }
            ('*', Some('/')) => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    chars.len()
}

/// Exclusive end of the literal/identifier opened by `quote` at `start`.
/// A doubled quote (`''` / `""`) is an escaped quote, not a terminator.
fn quoted_end(chars: &[char], start: usize, quote: char, backslash_escapes: bool) -> usize {
    let mut i = start + 1;
    while i < chars.len() {
        let c = chars[i];
        if backslash_escapes && c == '\\' {
            i += 2;
        } else if c == quote {
            if chars.get(i + 1) == Some(&quote) {
                i += 2;
            } else {
                return i + 1;
            }
        } else {
            i += 1;
        }
    }
    chars.len()
}

/// If a dollar-quote delimiter (`$$` or `$tag$`) starts at `start`, returns
/// the index of its closing `$`. Tags follow identifier rules, so they cannot
/// start with a digit and `$1` (a positional parameter) never opens a quote.
fn dollar_tag_end(chars: &[char], start: usize) -> Option<usize> {
    if matches!(chars.get(start + 1), Some(c) if c.is_ascii_digit()) {
        return None;
    }
    let mut j = start + 1;
    while j < chars.len() && (chars[j].is_alphanumeric() || chars[j] == '_') {
        j += 1;
    }
    if j < chars.len() && chars[j] == '$' {
        Some(j)
    } else {
        None
    }
}

/// First index at or after `from` where `needle` occurs in `haystack`.
fn find_chars(haystack: &[char], from: usize, needle: &[char]) -> Option<usize> {
    haystack
        .get(from..)?
        .windows(needle.len())
        .position(|window| window == needle)
        .map(|offset| from + offset)
}

// ─────────────────────────────────────────────
// Normalisation helpers
// ─────────────────────────────────────────────

289fn lower_to_parsed(stmt: Statement) -> ParsedStatement {
290    match stmt {
291        // ── CREATE TABLE ──────────────────────────────────────────────────
292        Statement::CreateTable(ct) => {
293            let table = ct.name.to_string();
294            let mut columns = Vec::new();
295            let mut foreign_keys = Vec::new();
296            let mut has_primary_key = false;
297
298            for col_def in &ct.columns {
299                let mut nullable = true;
300                let mut has_default = false;
301                let mut is_pk = false;
302
303                for opt in &col_def.options {
304                    match &opt.option {
305                        ColumnOption::NotNull => nullable = false,
306                        ColumnOption::Default(_) => has_default = true,
307                        ColumnOption::Unique { is_primary, .. } if *is_primary => {
308                            is_pk = true;
309                            has_primary_key = true;
310                            nullable = false;
311                        }
312                        ColumnOption::ForeignKey {
313                            foreign_table,
314                            referred_columns,
315                            on_delete,
316                            on_update,
317                            ..
318                        } => {
319                            foreign_keys.push(ForeignKeyInfo {
320                                columns: vec![col_def.name.to_string()],
321                                ref_table: foreign_table.to_string(),
322                                ref_columns: referred_columns
323                                    .iter()
324                                    .map(|c| c.to_string())
325                                    .collect(),
326                                on_delete_cascade: on_delete
327                                    .as_ref()
328                                    .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
329                                    .unwrap_or(false),
330                                on_update_cascade: on_update
331                                    .as_ref()
332                                    .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
333                                    .unwrap_or(false),
334                                constraint_name: None,
335                            });
336                        }
337                        _ => {}
338                    }
339                }
340
341                columns.push(ColumnInfo {
342                    name: col_def.name.to_string(),
343                    data_type: col_def.data_type.to_string(),
344                    nullable,
345                    has_default,
346                    is_primary_key: is_pk,
347                });
348            }
349
350            // Table-level constraints
351            for constraint in &ct.constraints {
352                match constraint {
353                    TableConstraint::ForeignKey {
354                        name,
355                        columns: fk_cols,
356                        foreign_table,
357                        referred_columns,
358                        on_delete,
359                        on_update,
360                        ..
361                    } => {
362                        foreign_keys.push(ForeignKeyInfo {
363                            columns: fk_cols.iter().map(|c| c.to_string()).collect(),
364                            ref_table: foreign_table.to_string(),
365                            ref_columns: referred_columns.iter().map(|c| c.to_string()).collect(),
366                            on_delete_cascade: on_delete
367                                .as_ref()
368                                .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
369                                .unwrap_or(false),
370                            on_update_cascade: on_update
371                                .as_ref()
372                                .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
373                                .unwrap_or(false),
374                            constraint_name: name.as_ref().map(|n| n.to_string()),
375                        });
376                    }
377                    TableConstraint::PrimaryKey { .. } | TableConstraint::Unique { .. } => {
378                        has_primary_key = true;
379                    }
380                    _ => {}
381                }
382            }
383
384            ParsedStatement::CreateTable {
385                table,
386                columns,
387                foreign_keys,
388                has_primary_key,
389            }
390        }
391
392        // ── DROP TABLE ────────────────────────────────────────────────────
393        Statement::Drop {
394            object_type: ObjectType::Table,
395            names,
396            if_exists,
397            cascade,
398            ..
399        } => ParsedStatement::DropTable {
400            tables: names.iter().map(|n| n.to_string()).collect(),
401            if_exists,
402            cascade,
403        },
404
405        // ── DROP INDEX ────────────────────────────────────────────────────
406        Statement::Drop {
407            object_type: ObjectType::Index,
408            names,
409            if_exists,
410            ..
411        } => {
412            let raw = names
413                .iter()
414                .map(|n| n.to_string())
415                .collect::<Vec<_>>()
416                .join(", ");
417            // Check if CONCURRENTLY keyword appears; sqlparser puts it in name
418            let concurrently = raw.to_uppercase().contains("CONCURRENTLY");
419            ParsedStatement::DropIndex {
420                names: names.iter().map(|n| n.to_string()).collect(),
421                concurrently,
422                if_exists,
423            }
424        }
425
426        // ── CREATE INDEX ──────────────────────────────────────────────────
427        Statement::CreateIndex(ci) => {
428            let table = ci.table_name.to_string();
429            let columns = ci.columns.iter().map(|c| c.expr.to_string()).collect();
430            ParsedStatement::CreateIndex {
431                index_name: ci.name.as_ref().map(|n| n.to_string()),
432                table,
433                columns,
434                unique: ci.unique,
435                concurrently: ci.concurrently,
436            }
437        }
438
439        // ── ALTER TABLE ───────────────────────────────────────────────────
440        Statement::AlterTable {
441            name, operations, ..
442        } => {
443            let table = name.to_string();
444
445            // We handle the first meaningful operation; one ALTER TABLE
446            // per statement is the common case.
447            for op in &operations {
448                match op {
449                    // ADD COLUMN
450                    AlterTableOperation::AddColumn { column_def, .. } => {
451                        let mut nullable = true;
452                        let mut has_default = false;
453                        let mut is_pk = false;
454
455                        for opt in &column_def.options {
456                            match &opt.option {
457                                ColumnOption::NotNull => nullable = false,
458                                ColumnOption::Default(_) => has_default = true,
459                                ColumnOption::Unique { is_primary, .. } if *is_primary => {
460                                    is_pk = true;
461                                }
462                                _ => {}
463                            }
464                        }
465
466                        return ParsedStatement::AlterTableAddColumn {
467                            table,
468                            column: ColumnInfo {
469                                name: column_def.name.to_string(),
470                                data_type: column_def.data_type.to_string(),
471                                nullable,
472                                has_default,
473                                is_primary_key: is_pk,
474                            },
475                        };
476                    }
477
478                    // DROP COLUMN
479                    AlterTableOperation::DropColumn {
480                        column_name,
481                        if_exists,
482                        ..
483                    } => {
484                        return ParsedStatement::AlterTableDropColumn {
485                            table,
486                            column: column_name.to_string(),
487                            if_exists: *if_exists,
488                        };
489                    }
490
491                    // ALTER COLUMN TYPE
492                    AlterTableOperation::AlterColumn { column_name, op } => {
493                        use sqlparser::ast::AlterColumnOperation;
494                        match op {
495                            AlterColumnOperation::SetDataType { data_type, .. } => {
496                                return ParsedStatement::AlterTableAlterColumnType {
497                                    table,
498                                    column: column_name.to_string(),
499                                    new_type: data_type.to_string(),
500                                };
501                            }
502                            AlterColumnOperation::SetNotNull => {
503                                return ParsedStatement::AlterTableSetNotNull {
504                                    table,
505                                    column: column_name.to_string(),
506                                };
507                            }
508                            AlterColumnOperation::DropDefault => {
509                                return ParsedStatement::AlterTableAlterColumnDefault {
510                                    table,
511                                    column: column_name.to_string(),
512                                    drop_default: true,
513                                };
514                            }
515                            AlterColumnOperation::SetDefault { .. } => {
516                                return ParsedStatement::AlterTableAlterColumnDefault {
517                                    table,
518                                    column: column_name.to_string(),
519                                    drop_default: false,
520                                };
521                            }
522                            _ => {}
523                        }
524                    }
525
526                    // ADD CONSTRAINT (FK, PK, unique)
527                    AlterTableOperation::AddConstraint(constraint) => match constraint {
528                        TableConstraint::ForeignKey {
529                            name,
530                            columns: fk_cols,
531                            foreign_table,
532                            referred_columns,
533                            on_delete,
534                            on_update,
535                            ..
536                        } => {
537                            return ParsedStatement::AlterTableAddForeignKey {
538                                table,
539                                fk: ForeignKeyInfo {
540                                    columns: fk_cols.iter().map(|c| c.to_string()).collect(),
541                                    ref_table: foreign_table.to_string(),
542                                    ref_columns: referred_columns
543                                        .iter()
544                                        .map(|c| c.to_string())
545                                        .collect(),
546                                    on_delete_cascade: on_delete
547                                        .as_ref()
548                                        .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
549                                        .unwrap_or(false),
550                                    on_update_cascade: on_update
551                                        .as_ref()
552                                        .map(|a| a.to_string().to_uppercase().contains("CASCADE"))
553                                        .unwrap_or(false),
554                                    constraint_name: name.as_ref().map(|n| n.to_string()),
555                                },
556                            };
557                        }
558                        TableConstraint::PrimaryKey { columns, .. } => {
559                            return ParsedStatement::AlterTableAddPrimaryKey {
560                                table,
561                                columns: columns.iter().map(|c| c.to_string()).collect(),
562                            };
563                        }
564                        _ => {}
565                    },
566
567                    // DROP CONSTRAINT
568                    AlterTableOperation::DropConstraint { name, cascade, .. } => {
569                        return ParsedStatement::AlterTableDropConstraint {
570                            table,
571                            constraint: name.to_string(),
572                            cascade: *cascade,
573                        };
574                    }
575
576                    // RENAME COLUMN
577                    AlterTableOperation::RenameColumn {
578                        old_column_name,
579                        new_column_name,
580                    } => {
581                        return ParsedStatement::AlterTableRenameColumn {
582                            table,
583                            old: old_column_name.to_string(),
584                            new: new_column_name.to_string(),
585                        };
586                    }
587
588                    // RENAME TABLE
589                    AlterTableOperation::RenameTable { table_name } => {
590                        return ParsedStatement::AlterTableRenameTable {
591                            old: table,
592                            new: table_name.to_string(),
593                        };
594                    }
595
596                    _ => {}
597                }
598            }
599
600            // Unrecognised ALTER TABLE operation
601            ParsedStatement::Other {
602                raw: format!("ALTER TABLE {}", name),
603            }
604        }
605
606        other => ParsedStatement::Other {
607            raw: other.to_string().chars().take(80).collect(),
608        },
609    }
610}