Skip to main content

waypoint_core/
hooks.rs

1//! SQL callback hooks that run before/after migrations (Flyway-compatible).
2
3use std::collections::HashMap;
4use std::fmt;
5use std::path::PathBuf;
6
7use tokio_postgres::Client;
8
9use crate::config::HooksConfig;
10use crate::db;
11use crate::error::{Result, WaypointError};
12use crate::placeholder::replace_placeholders;
13
/// Lifecycle point at which a callback hook is executed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HookType {
    /// Executed a single time, before any migration in the run is applied.
    BeforeMigrate,
    /// Executed a single time, once the whole migration run has finished.
    AfterMigrate,
    /// Executed ahead of every individual migration.
    BeforeEachMigrate,
    /// Executed after every individual migration.
    AfterEachMigrate,
}

impl fmt::Display for HookType {
    /// Writes the camelCase hook name, the same spelling used as the callback
    /// file-name prefix (e.g. `beforeMigrate`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::BeforeMigrate => "beforeMigrate",
            Self::AfterMigrate => "afterMigrate",
            Self::BeforeEachMigrate => "beforeEachMigrate",
            Self::AfterEachMigrate => "afterEachMigrate",
        };
        f.write_str(label)
    }
}
37
38/// A hook SQL script discovered on disk or specified in config.
39#[derive(Debug, Clone)]
40pub struct ResolvedHook {
41    /// The phase at which this hook should be executed.
42    pub hook_type: HookType,
43    /// Filename of the hook SQL script.
44    pub script_name: String,
45    /// Raw SQL content of the hook file.
46    pub sql: String,
47}
48
49/// File prefixes that indicate hook callback files (Flyway-compatible).
50type HookPrefixEntry = (&'static str, fn() -> HookType);
51const HOOK_PREFIXES: &[HookPrefixEntry] = &[
52    ("beforeEachMigrate", || HookType::BeforeEachMigrate),
53    ("afterEachMigrate", || HookType::AfterEachMigrate),
54    ("beforeMigrate", || HookType::BeforeMigrate),
55    ("afterMigrate", || HookType::AfterMigrate),
56];
57
58/// Check if a filename is a hook callback file (not a migration).
59pub fn is_hook_file(filename: &str) -> bool {
60    HOOK_PREFIXES
61        .iter()
62        .any(|(prefix, _)| filename.starts_with(prefix) && filename.ends_with(".sql"))
63}
64
65/// Scan migration locations for SQL callback hook files.
66///
67/// Recognizes:
68///   - `beforeMigrate.sql` / `beforeMigrate__*.sql`
69///   - `afterMigrate.sql` / `afterMigrate__*.sql`
70///   - `beforeEachMigrate.sql` / `beforeEachMigrate__*.sql`
71///   - `afterEachMigrate.sql` / `afterEachMigrate__*.sql`
72///
73/// Multiple files per hook type are sorted alphabetically.
74pub fn scan_hooks(locations: &[PathBuf]) -> Result<Vec<ResolvedHook>> {
75    let mut hooks = Vec::new();
76
77    for location in locations {
78        if !location.exists() {
79            continue;
80        }
81
82        let entries = std::fs::read_dir(location).map_err(|e| {
83            WaypointError::IoError(std::io::Error::new(
84                e.kind(),
85                format!(
86                    "Failed to read hook directory '{}': {}",
87                    location.display(),
88                    e
89                ),
90            ))
91        })?;
92
93        let mut files: Vec<_> = entries
94            .filter_map(|e| e.ok())
95            .filter(|e| e.path().is_file())
96            .collect();
97
98        // Sort alphabetically for deterministic ordering
99        files.sort_by_key(|e| e.file_name());
100
101        for entry in files {
102            let path = entry.path();
103            let filename = match path.file_name().and_then(|n| n.to_str()) {
104                Some(name) => name.to_string(),
105                None => continue,
106            };
107
108            if !filename.ends_with(".sql") {
109                continue;
110            }
111
112            // Check each hook prefix
113            for (prefix, type_fn) in HOOK_PREFIXES {
114                if filename.starts_with(prefix) {
115                    // Must be exactly `prefix.sql` or `prefix__*.sql`
116                    let rest = &filename[prefix.len()..filename.len() - 4]; // strip prefix and .sql
117                    if rest.is_empty() || rest.starts_with("__") {
118                        let sql = std::fs::read_to_string(&path)?;
119                        hooks.push(ResolvedHook {
120                            hook_type: type_fn(),
121                            script_name: filename.clone(),
122                            sql,
123                        });
124                        break;
125                    }
126                }
127            }
128        }
129    }
130
131    // Sort within each hook type alphabetically by script name
132    hooks.sort_by(|a, b| {
133        a.hook_type
134            .to_string()
135            .cmp(&b.hook_type.to_string())
136            .then_with(|| a.script_name.cmp(&b.script_name))
137    });
138
139    Ok(hooks)
140}
141
142/// Load hook SQL files specified in the TOML `[hooks]` config section.
143pub fn load_config_hooks(config: &HooksConfig) -> Result<Vec<ResolvedHook>> {
144    let mut hooks = Vec::new();
145
146    let sections: &[(HookType, &[PathBuf])] = &[
147        (HookType::BeforeMigrate, &config.before_migrate),
148        (HookType::AfterMigrate, &config.after_migrate),
149        (HookType::BeforeEachMigrate, &config.before_each_migrate),
150        (HookType::AfterEachMigrate, &config.after_each_migrate),
151    ];
152
153    for (hook_type, paths) in sections {
154        for path in *paths {
155            let sql = std::fs::read_to_string(path).map_err(|e| {
156                WaypointError::IoError(std::io::Error::new(
157                    e.kind(),
158                    format!("Failed to read hook file '{}': {}", path.display(), e),
159                ))
160            })?;
161
162            let script_name = path
163                .file_name()
164                .and_then(|n| n.to_str())
165                .unwrap_or_else(|| path.to_str().unwrap_or("unknown"))
166                .to_string();
167
168            hooks.push(ResolvedHook {
169                hook_type: hook_type.clone(),
170                script_name,
171                sql,
172            });
173        }
174    }
175
176    Ok(hooks)
177}
178
179/// Run all hooks of a given type.
180///
181/// Returns total execution time in milliseconds.
182pub async fn run_hooks(
183    client: &Client,
184    hooks: &[ResolvedHook],
185    phase: &HookType,
186    placeholders: &HashMap<String, String>,
187) -> Result<(usize, i32)> {
188    let mut total_ms = 0;
189    let mut count = 0;
190
191    for hook in hooks.iter().filter(|h| &h.hook_type == phase) {
192        log::info!("Running {} hook: {}", phase, hook.script_name);
193
194        let sql = replace_placeholders(&hook.sql, placeholders)?;
195
196        match db::execute_in_transaction(client, &sql).await {
197            Ok(exec_time) => {
198                total_ms += exec_time;
199                count += 1;
200            }
201            Err(e) => {
202                let reason = match &e {
203                    WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
204                    other => other.to_string(),
205                };
206                return Err(WaypointError::HookFailed {
207                    phase: phase.to_string(),
208                    script: hook.script_name.clone(),
209                    reason,
210                });
211            }
212        }
213    }
214
215    Ok((count, total_ms))
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create an empty scratch directory unique to this test name and process, so
    /// concurrent runs of the suite sharing one temp dir don't clobber each other.
    fn create_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "waypoint_hooks_test_{}_{}",
            name,
            std::process::id()
        ));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_is_hook_file() {
        assert!(is_hook_file("beforeMigrate.sql"));
        assert!(is_hook_file("afterMigrate.sql"));
        assert!(is_hook_file("beforeEachMigrate.sql"));
        assert!(is_hook_file("afterEachMigrate.sql"));
        assert!(is_hook_file("beforeMigrate__Disable_triggers.sql"));
        assert!(is_hook_file("afterMigrate__Refresh_views.sql"));

        assert!(!is_hook_file("V1__Create_table.sql"));
        assert!(!is_hook_file("R__Create_view.sql"));
        assert!(!is_hook_file("beforeMigrate.txt"));
        assert!(!is_hook_file("random.sql"));
    }

    #[test]
    fn test_hook_type_display() {
        assert_eq!(HookType::BeforeMigrate.to_string(), "beforeMigrate");
        assert_eq!(HookType::AfterMigrate.to_string(), "afterMigrate");
        assert_eq!(HookType::BeforeEachMigrate.to_string(), "beforeEachMigrate");
        assert_eq!(HookType::AfterEachMigrate.to_string(), "afterEachMigrate");
    }

    #[test]
    fn test_scan_hooks_finds_callback_files() {
        let dir = create_temp_dir("scan");
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("afterMigrate__Refresh_views.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("V1__Create_table.sql"), "CREATE TABLE t(id INT);").unwrap();
        fs::write(dir.join("R__Create_view.sql"), "CREATE VIEW v AS SELECT 1;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 2);

        let before: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .collect();
        let after: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::AfterMigrate)
            .collect();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].script_name, "beforeMigrate.sql");
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].script_name, "afterMigrate__Refresh_views.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_multiple_sorted_alphabetically() {
        let dir = create_temp_dir("multi");
        fs::write(dir.join("beforeMigrate__B_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("beforeMigrate__A_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 0;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 3);
        assert_eq!(hooks[0].script_name, "beforeMigrate.sql");
        assert_eq!(hooks[1].script_name, "beforeMigrate__A_first.sql");
        assert_eq!(hooks[2].script_name, "beforeMigrate__B_second.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_ignores_near_miss_names() {
        let dir = create_temp_dir("near_miss");
        fs::write(dir.join("afterMigrateError.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeMigrate_single.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("afterEachMigrate__Log.sql"), "SELECT 3;").unwrap();

        // The prefix-only check flags the near miss, but the scanner must not load it.
        assert!(is_hook_file("afterMigrateError.sql"));

        let hooks = scan_hooks(&[dir.clone()]).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::AfterEachMigrate);
        assert_eq!(hooks[0].script_name, "afterEachMigrate__Log.sql");
        assert_eq!(hooks[0].sql, "SELECT 3;");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_skips_missing_location() {
        let hooks = scan_hooks(&[PathBuf::from("/nonexistent/waypoint_hooks_dir")]).unwrap();
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_load_config_hooks() {
        let dir = create_temp_dir("config");
        let hook_file = dir.join("pre.sql");
        fs::write(&hook_file, "SET work_mem = '256MB';").unwrap();

        let config = HooksConfig {
            before_migrate: vec![hook_file],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].script_name, "pre.sql");
        assert_eq!(hooks[0].sql, "SET work_mem = '256MB';");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_load_config_hooks_missing_file() {
        let config = HooksConfig {
            before_migrate: vec![PathBuf::from("/nonexistent/hook.sql")],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        assert!(load_config_hooks(&config).is_err());
    }
}