Skip to main content

safe_migrate/
rules.rs

1use crate::config::{Config, get_recipe};
2use crate::model::{AlterAction, CacheData, LintRecord, LockTier, MigrationOp, SpannedOp};
3
4pub fn evaluate(
5    file_path: &str,
6    spanned_ops: Vec<SpannedOp>,
7    cache: &CacheData,
8    default_schema: &str,
9    config: &Config,
10) -> Vec<LintRecord> {
11    let mut records = Vec::new();
12
13    for spanned in spanned_ops {
14        let SpannedOp { op, start, end } = spanned;
15
16        let mut evaluate_rule = |rule_key: &str, is_large: bool, base_msg: String| {
17            let rule_cfg = config.rules.get(rule_key);
18            let tier = if is_large {
19                rule_cfg.map(|r| r.tier.clone()).unwrap_or(LockTier::Tier1)
20            } else {
21                LockTier::Tier3
22            };
23
24            let recipe = if rule_cfg.is_some() {
25                get_recipe(rule_key).to_string()
26            } else {
27                String::new()
28            };
29
30            records.push(LintRecord {
31                file: file_path.to_string(),
32                start,
33                end,
34                tier,
35                op: op.clone(),
36                message: base_msg,
37                rule_name: rule_key.to_string(),
38                recipe,
39            });
40        };
41
42        match &op {
43            MigrationOp::Ignored(cmd) => {
44                evaluate_rule(
45                    "benign-statement",
46                    false,
47                    format!("Benign statement '{}' is ignored.", cmd),
48                );
49            }
50            MigrationOp::CreateTable(table) => {
51                evaluate_rule(
52                    "create-table",
53                    false,
54                    format!(
55                        "CREATE TABLE '{}' is safe.",
56                        table.canonical_key(default_schema)
57                    ),
58                );
59            }
60            MigrationOp::Unknown { reason, .. } => {
61                evaluate_rule(
62                    "executing-unclassified-statement",
63                    true,
64                    format!("Unclassified statement: {}", reason),
65                );
66            }
67            MigrationOp::DropTable(table) => {
68                let key = table.canonical_key(default_schema);
69                let rows = cache
70                    .tables
71                    .get(&key)
72                    .map(|s| s.estimated_rows)
73                    .unwrap_or(u64::MAX);
74                let threshold = config
75                    .rules
76                    .get("ban-drop-table")
77                    .and_then(|r| r.threshold)
78                    .unwrap_or(config.default_threshold);
79                evaluate_rule(
80                    "ban-drop-table",
81                    rows > threshold,
82                    format!("Dropping table '{}' (~{} rows).", key, rows),
83                );
84            }
85            MigrationOp::CreateIndex {
86                table,
87                concurrently,
88                ..
89            } => {
90                let key = table.canonical_key(default_schema);
91                let rows = cache
92                    .tables
93                    .get(&key)
94                    .map(|s| s.estimated_rows)
95                    .unwrap_or(u64::MAX);
96                let threshold = config
97                    .rules
98                    .get("require-concurrent-index-creation")
99                    .and_then(|r| r.threshold)
100                    .unwrap_or(config.default_threshold);
101
102                if !concurrently {
103                    evaluate_rule(
104                        "require-concurrent-index-creation",
105                        rows > threshold,
106                        format!("Building index on '{}' without CONCURRENTLY.", key),
107                    );
108                } else {
109                    evaluate_rule(
110                        "require-concurrent-index-creation",
111                        false,
112                        format!("Building index on '{}' with CONCURRENTLY is safe.", key),
113                    );
114                }
115            }
116            MigrationOp::DropIndex {
117                indexes,
118                concurrently,
119            } => {
120                let threshold = config
121                    .rules
122                    .get("require-concurrent-index-deletion")
123                    .and_then(|r| r.threshold)
124                    .unwrap_or(config.default_threshold);
125
126                for index in indexes {
127                    let key = index.canonical_key(default_schema);
128                    let rows = cache
129                        .indexes
130                        .get(&key)
131                        .and_then(|table_key| cache.tables.get(table_key))
132                        .map(|s| s.estimated_rows)
133                        .unwrap_or(u64::MAX);
134
135                    if !concurrently {
136                        evaluate_rule(
137                            "require-concurrent-index-deletion",
138                            rows > threshold,
139                            format!("Dropping index '{}' without CONCURRENTLY.", key),
140                        );
141                    } else {
142                        evaluate_rule(
143                            "require-concurrent-index-deletion",
144                            false,
145                            format!("Dropping index '{}' with CONCURRENTLY is safe.", key),
146                        );
147                    }
148                }
149            }
150            MigrationOp::AlterTable { table, actions } => {
151                let key = table.canonical_key(default_schema);
152                let rows = cache
153                    .tables
154                    .get(&key)
155                    .map(|s| s.estimated_rows)
156                    .unwrap_or(u64::MAX);
157
158                for action in actions {
159                    // Removed the unreachable catch-all so the compiler strictly checks all future variants
160                    match action {
161                        AlterAction::AddColumn => {
162                            let threshold = config
163                                .rules
164                                .get("adding-field-with-default")
165                                .and_then(|r| r.threshold)
166                                .unwrap_or(config.default_threshold);
167                            evaluate_rule(
168                                "adding-field-with-default",
169                                rows > threshold,
170                                format!(
171                                    "Adding column to '{}'. Verify it lacks a VOLATILE default.",
172                                    key
173                                ),
174                            );
175                        }
176                        AlterAction::AlterColumnUnspecified => {
177                            let threshold = config
178                                .rules
179                                .get("changing-column-type")
180                                .and_then(|r| r.threshold)
181                                .unwrap_or(config.default_threshold);
182                            evaluate_rule(
183                                "changing-column-type",
184                                rows > threshold,
185                                format!("Altering column on '{}'.", key),
186                            );
187                        }
188                        AlterAction::DropColumn => {
189                            let threshold = config
190                                .rules
191                                .get("ban-drop-column")
192                                .and_then(|r| r.threshold)
193                                .unwrap_or(config.default_threshold);
194                            evaluate_rule(
195                                "ban-drop-column",
196                                rows > threshold,
197                                format!("Dropping column from '{}'.", key),
198                            );
199                        }
200                        AlterAction::Other => {
201                            evaluate_rule(
202                                "executing-unclassified-statement",
203                                true,
204                                format!("Unclassified ALTER TABLE operation on '{}'.", key),
205                            );
206                        }
207                    }
208                }
209            }
210        }
211    }
212
213    records
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::parse_and_classify;
    use crate::config::Config;
    use crate::model::{CacheData, CacheEntry};
    use squawk_syntax::ast::SourceFile;
    use std::collections::HashMap;

    /// Builds a fixture cache plus `Config::default_config()`.
    ///
    /// The cache holds one large table (`public.users`, 5M rows), one small table
    /// (`public.config`, 50 rows), and one index (`public.idx_users_email`) on the large table.
    fn setup_mock_env() -> (CacheData, Config) {
        let mut tables = HashMap::new();
        tables.insert(
            "public.users".to_string(),
            CacheEntry {
                estimated_rows: 5_000_000,
                relpages: Some(1000),
            },
        );
        tables.insert(
            "public.config".to_string(),
            CacheEntry {
                estimated_rows: 50,
                relpages: Some(1),
            },
        );

        let mut indexes = HashMap::new();
        indexes.insert(
            "public.idx_users_email".to_string(),
            "public.users".to_string(),
        );

        let cache = CacheData {
            last_updated: 0,
            tables,
            indexes,
        };
        let config = Config::default_config();

        (cache, config)
    }

    /// Parses and classifies `sql`, then lints it as `test.sql` with `public` as default schema.
    fn run_lint(sql: &str, cache: &CacheData, config: &Config) -> Vec<LintRecord> {
        let parse_result = SourceFile::parse(sql);
        let ops = parse_and_classify(parse_result.tree()).expect("AST parse failed");
        evaluate("test.sql", ops, cache, "public", config)
    }

    #[test]
    fn test_safe_statements_ignored() {
        let (cache, config) = setup_mock_env();
        let records = run_lint(
            "BEGIN; COMMIT; SET statement_timeout = '2s';",
            &cache,
            &config,
        );
        assert_eq!(records.len(), 3);
        assert!(records.iter().all(|r| r.tier == LockTier::Tier3));
    }

    #[test]
    fn test_add_column_large_table_fails() {
        let (cache, config) = setup_mock_env();
        let records = run_lint("ALTER TABLE users ADD COLUMN bio TEXT;", &cache, &config);
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].tier, LockTier::Tier1);
        assert_eq!(records[0].rule_name, "adding-field-with-default");
    }

    #[test]
    fn test_add_column_small_table_passes() {
        let (cache, config) = setup_mock_env();
        let records = run_lint(
            "ALTER TABLE config ADD COLUMN flag BOOLEAN;",
            &cache,
            &config,
        );
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].tier, LockTier::Tier3);
    }

    #[test]
    fn test_drop_index_concurrent_logic() {
        let (cache, config) = setup_mock_env();

        let bad = run_lint("DROP INDEX idx_users_email;", &cache, &config);
        // Check the count first so a missing record fails with a clear message, not an index panic.
        assert_eq!(bad.len(), 1);
        assert_eq!(bad[0].tier, LockTier::Tier2);
        assert_eq!(bad[0].rule_name, "require-concurrent-index-deletion");

        let good = run_lint("DROP INDEX CONCURRENTLY idx_users_email;", &cache, &config);
        assert_eq!(good.len(), 1);
        assert_eq!(good[0].tier, LockTier::Tier3);
    }

    #[test]
    fn test_drop_index_missing_from_cache_is_treated_as_large() {
        let (cache, config) = setup_mock_env();

        // The index is not cached, so its table size is unknown and must not be assumed small.
        let records = run_lint("DROP INDEX idx_not_cached;", &cache, &config);
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].rule_name, "require-concurrent-index-deletion");
        assert_ne!(records[0].tier, LockTier::Tier3);
    }

    #[test]
    fn test_multi_table_drop_guardrail() {
        let sql = "DROP TABLE users, config;";
        let parse_result = SourceFile::parse(sql);
        let ops = parse_and_classify(parse_result.tree()).expect("AST parse failed");

        assert_eq!(ops.len(), 1);
        match &ops[0].op {
            crate::model::MigrationOp::Unknown { reason, .. } => {
                assert!(reason.contains("Multi-table DROP TABLE is not safely verified"));
            }
            _ => panic!("Parser failed to catch multi-table DROP and output Unknown"),
        }
    }
}