//! `safe_migrate/ast.rs` — classifies parsed SQL migration statements into migration operations.

use squawk_syntax::ast::{self, AstNode};

use crate::error::{Result, SafeMigrateError};
use crate::model::{AlterAction, MigrationOp, SpannedOp};
use crate::resolve::{extract_index_identity, extract_table_identity, normalize_ident};

6fn extract_alter_actions(alter_table: &ast::AlterTable) -> Vec<AlterAction> {
7    alter_table
8        .actions()
9        .map(|action| match action {
10            ast::AlterTableAction::AddColumn(_) => AlterAction::AddColumn,
11            ast::AlterTableAction::DropColumn(_) => AlterAction::DropColumn,
12            ast::AlterTableAction::AlterColumn(_) => AlterAction::AlterColumnUnspecified,
13            _ => AlterAction::Other,
14        })
15        .collect()
16}
17
18pub fn parse_and_classify(source_file: ast::SourceFile) -> Result<Vec<SpannedOp>> {
19    let mut ops = Vec::new();
20
21    for stmt in source_file.stmts() {
22        let syntax = stmt.syntax();
23        let range = syntax.text_range();
24        let start = u32::from(range.start());
25        let end = u32::from(range.end());
26
27        // Refactored to return the ops array directly to avoid `mut` closure warnings
28        let parse_op = || -> Result<Vec<MigrationOp>> {
29            let mut local_ops = Vec::new();
30
31            // 1. Whitelist Safe Statements
32            if ast::Begin::cast(syntax.clone()).is_some() {
33                local_ops.push(MigrationOp::Ignored("BEGIN".into()));
34                return Ok(local_ops);
35            }
36            if ast::Commit::cast(syntax.clone()).is_some() {
37                local_ops.push(MigrationOp::Ignored("COMMIT".into()));
38                return Ok(local_ops);
39            }
40            if ast::Analyze::cast(syntax.clone()).is_some() {
41                local_ops.push(MigrationOp::Ignored("ANALYZE".into()));
42                return Ok(local_ops);
43            }
44            if ast::Set::cast(syntax.clone()).is_some() {
45                local_ops.push(MigrationOp::Ignored("SET".into()));
46                return Ok(local_ops);
47            }
48
49            // FIX: Safely bypass DML (Data Migrations) using the exact AST structs
50            if ast::Insert::cast(syntax.clone()).is_some()
51                || ast::Update::cast(syntax.clone()).is_some()
52                || ast::Delete::cast(syntax.clone()).is_some()
53                || ast::Select::cast(syntax.clone()).is_some()
54            {
55                local_ops.push(MigrationOp::Ignored("DML".into()));
56                return Ok(local_ops);
57            }
58
59            // 2. Table Operations
60            if let Some(drop_table) = ast::DropTable::cast(syntax.clone()) {
61                if drop_table.comma_token().is_some() {
62                    return Err(SafeMigrateError::Parse("Multi-table DROP TABLE is not safely verified yet. Split into multiple statements.".into()));
63                }
64
65                let path = drop_table
66                    .path()
67                    .ok_or_else(|| SafeMigrateError::Parse("DropTable missing path".into()))?;
68
69                local_ops.push(MigrationOp::DropTable(extract_table_identity(path)?));
70                return Ok(local_ops);
71            }
72
73            if let Some(create_table) = ast::CreateTable::cast(syntax.clone()) {
74                let path = create_table
75                    .path()
76                    .ok_or_else(|| SafeMigrateError::Parse("CreateTable missing path".into()))?;
77
78                local_ops.push(MigrationOp::CreateTable(extract_table_identity(path)?));
79                return Ok(local_ops);
80            }
81
82            // 3. Index Operations
83            if let Some(drop_index) = ast::DropIndex::cast(syntax.clone()) {
84                let concurrently = drop_index.concurrently_token().is_some();
85                let mut indexes = Vec::new();
86                for path in drop_index.paths() {
87                    indexes.push(extract_index_identity(path)?);
88                }
89                if indexes.is_empty() {
90                    return Err(SafeMigrateError::Parse("DropIndex missing paths".into()));
91                }
92                local_ops.push(MigrationOp::DropIndex {
93                    indexes,
94                    concurrently,
95                });
96                return Ok(local_ops);
97            }
98
99            if let Some(create_index) = ast::CreateIndex::cast(syntax.clone()) {
100                let table_path = create_index
101                    .relation_name()
102                    .and_then(|rel| rel.path())
103                    .ok_or_else(|| {
104                        SafeMigrateError::Parse("CreateIndex missing target table".into())
105                    })?;
106
107                let index_name = create_index
108                    .name()
109                    .and_then(|n| n.ident_token().map(|t| normalize_ident(t.text())));
110
111                local_ops.push(MigrationOp::CreateIndex {
112                    index_name,
113                    table: extract_table_identity(table_path)?,
114                    concurrently: create_index.concurrently_token().is_some(),
115                });
116                return Ok(local_ops);
117            }
118
119            // 4. Alter Table
120            if let Some(alter_table) = ast::AlterTable::cast(syntax.clone()) {
121                let path = alter_table
122                    .relation_name()
123                    .and_then(|rel| rel.path())
124                    .ok_or_else(|| {
125                        SafeMigrateError::Parse("AlterTable missing relation name".into())
126                    })?;
127
128                local_ops.push(MigrationOp::AlterTable {
129                    table: extract_table_identity(path)?,
130                    actions: extract_alter_actions(&alter_table),
131                });
132                return Ok(local_ops);
133            }
134
135            Err(SafeMigrateError::Parse(
136                "Statement type not explicitly supported".into(),
137            ))
138        };
139
140        match parse_op() {
141            Ok(parsed_ops) => {
142                for op in parsed_ops {
143                    ops.push(SpannedOp { op, start, end });
144                }
145            }
146            Err(e) => {
147                ops.push(SpannedOp {
148                    op: MigrationOp::Unknown {
149                        raw: syntax.text().to_string(),
150                        reason: e.to_string(),
151                    },
152                    start,
153                    end,
154                });
155            }
156        }
157    }
158
159    Ok(ops)
160}