Skip to main content

waypoint_core/
hooks.rs

1//! SQL callback hooks that run before/after migrations (Flyway-compatible).
2
3use std::collections::HashMap;
4use std::fmt;
5use std::path::PathBuf;
6
7use tokio_postgres::Client;
8
9use crate::config::{HooksConfig, WaypointConfig};
10use crate::db;
11use crate::error::{Result, WaypointError};
12use crate::placeholder::replace_placeholders;
13
/// The phase of a migration run in which a hook script executes.
///
/// The `Display` form is the Flyway callback name for the phase
/// (e.g. `beforeMigrate`), which is also the file-name prefix used for hook files.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookType {
    /// Flyway `beforeMigrate` callback.
    BeforeMigrate,
    /// Flyway `afterMigrate` callback.
    AfterMigrate,
    /// Flyway `beforeEachMigrate` callback.
    BeforeEachMigrate,
    /// Flyway `afterEachMigrate` callback.
    AfterEachMigrate,
}

impl fmt::Display for HookType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let callback_name = match self {
            HookType::BeforeMigrate => "beforeMigrate",
            HookType::AfterMigrate => "afterMigrate",
            HookType::BeforeEachMigrate => "beforeEachMigrate",
            HookType::AfterEachMigrate => "afterEachMigrate",
        };
        f.write_str(callback_name)
    }
}
33
34/// A hook SQL script discovered on disk or specified in config.
35#[derive(Debug, Clone)]
36pub struct ResolvedHook {
37    pub hook_type: HookType,
38    pub script_name: String,
39    pub sql: String,
40}
41
42/// File prefixes that indicate hook callback files (Flyway-compatible).
43type HookPrefixEntry = (&'static str, fn() -> HookType);
44const HOOK_PREFIXES: &[HookPrefixEntry] = &[
45    ("beforeEachMigrate", || HookType::BeforeEachMigrate),
46    ("afterEachMigrate", || HookType::AfterEachMigrate),
47    ("beforeMigrate", || HookType::BeforeMigrate),
48    ("afterMigrate", || HookType::AfterMigrate),
49];
50
51/// Check if a filename is a hook callback file (not a migration).
52pub fn is_hook_file(filename: &str) -> bool {
53    HOOK_PREFIXES
54        .iter()
55        .any(|(prefix, _)| filename.starts_with(prefix) && filename.ends_with(".sql"))
56}
57
58/// Scan migration locations for SQL callback hook files.
59///
60/// Recognizes:
61///   - `beforeMigrate.sql` / `beforeMigrate__*.sql`
62///   - `afterMigrate.sql` / `afterMigrate__*.sql`
63///   - `beforeEachMigrate.sql` / `beforeEachMigrate__*.sql`
64///   - `afterEachMigrate.sql` / `afterEachMigrate__*.sql`
65///
66/// Multiple files per hook type are sorted alphabetically.
67pub fn scan_hooks(locations: &[PathBuf]) -> Result<Vec<ResolvedHook>> {
68    let mut hooks = Vec::new();
69
70    for location in locations {
71        if !location.exists() {
72            continue;
73        }
74
75        let entries = std::fs::read_dir(location).map_err(|e| {
76            WaypointError::IoError(std::io::Error::new(
77                e.kind(),
78                format!(
79                    "Failed to read hook directory '{}': {}",
80                    location.display(),
81                    e
82                ),
83            ))
84        })?;
85
86        let mut files: Vec<_> = entries
87            .filter_map(|e| e.ok())
88            .filter(|e| e.path().is_file())
89            .collect();
90
91        // Sort alphabetically for deterministic ordering
92        files.sort_by_key(|e| e.file_name());
93
94        for entry in files {
95            let path = entry.path();
96            let filename = match path.file_name().and_then(|n| n.to_str()) {
97                Some(name) => name.to_string(),
98                None => continue,
99            };
100
101            if !filename.ends_with(".sql") {
102                continue;
103            }
104
105            // Check each hook prefix
106            for (prefix, type_fn) in HOOK_PREFIXES {
107                if filename.starts_with(prefix) {
108                    // Must be exactly `prefix.sql` or `prefix__*.sql`
109                    let rest = &filename[prefix.len()..filename.len() - 4]; // strip prefix and .sql
110                    if rest.is_empty() || rest.starts_with("__") {
111                        let sql = std::fs::read_to_string(&path)?;
112                        hooks.push(ResolvedHook {
113                            hook_type: type_fn(),
114                            script_name: filename.clone(),
115                            sql,
116                        });
117                        break;
118                    }
119                }
120            }
121        }
122    }
123
124    // Sort within each hook type alphabetically by script name
125    hooks.sort_by(|a, b| {
126        a.hook_type
127            .to_string()
128            .cmp(&b.hook_type.to_string())
129            .then_with(|| a.script_name.cmp(&b.script_name))
130    });
131
132    Ok(hooks)
133}
134
135/// Load hook SQL files specified in the TOML `[hooks]` config section.
136pub fn load_config_hooks(config: &HooksConfig) -> Result<Vec<ResolvedHook>> {
137    let mut hooks = Vec::new();
138
139    let sections: &[(HookType, &[PathBuf])] = &[
140        (HookType::BeforeMigrate, &config.before_migrate),
141        (HookType::AfterMigrate, &config.after_migrate),
142        (HookType::BeforeEachMigrate, &config.before_each_migrate),
143        (HookType::AfterEachMigrate, &config.after_each_migrate),
144    ];
145
146    for (hook_type, paths) in sections {
147        for path in *paths {
148            let sql = std::fs::read_to_string(path).map_err(|e| {
149                WaypointError::IoError(std::io::Error::new(
150                    e.kind(),
151                    format!("Failed to read hook file '{}': {}", path.display(), e),
152                ))
153            })?;
154
155            let script_name = path
156                .file_name()
157                .and_then(|n| n.to_str())
158                .unwrap_or_else(|| path.to_str().unwrap_or("unknown"))
159                .to_string();
160
161            hooks.push(ResolvedHook {
162                hook_type: hook_type.clone(),
163                script_name,
164                sql,
165            });
166        }
167    }
168
169    Ok(hooks)
170}
171
172/// Run all hooks of a given type.
173///
174/// Returns total execution time in milliseconds.
175pub async fn run_hooks(
176    client: &Client,
177    _config: &WaypointConfig,
178    hooks: &[ResolvedHook],
179    phase: &HookType,
180    placeholders: &HashMap<String, String>,
181) -> Result<(usize, i32)> {
182    let mut total_ms = 0;
183    let mut count = 0;
184
185    for hook in hooks.iter().filter(|h| &h.hook_type == phase) {
186        tracing::info!("Running {} hook: {}", phase, hook.script_name);
187
188        let sql = replace_placeholders(&hook.sql, placeholders)?;
189
190        match db::execute_in_transaction(client, &sql).await {
191            Ok(exec_time) => {
192                total_ms += exec_time;
193                count += 1;
194            }
195            Err(e) => {
196                let reason = match &e {
197                    WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
198                    other => other.to_string(),
199                };
200                return Err(WaypointError::HookFailed {
201                    phase: phase.to_string(),
202                    script: hook.script_name.clone(),
203                    reason,
204                });
205            }
206        }
207    }
208
209    Ok((count, total_ms))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create a fresh, empty scratch directory under the system temp dir.
    ///
    /// The process id is part of the name, so concurrent test runs sharing the temp
    /// dir (e.g. two `cargo test` invocations) do not delete each other's fixtures.
    fn create_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "waypoint_hooks_test_{}_{}",
            name,
            std::process::id()
        ));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_is_hook_file() {
        assert!(is_hook_file("beforeMigrate.sql"));
        assert!(is_hook_file("afterMigrate.sql"));
        assert!(is_hook_file("beforeEachMigrate.sql"));
        assert!(is_hook_file("afterEachMigrate.sql"));
        assert!(is_hook_file("beforeMigrate__Disable_triggers.sql"));
        assert!(is_hook_file("afterMigrate__Refresh_views.sql"));
        // Other Flyway callback names are still kept out of migration scanning.
        assert!(is_hook_file("afterMigrateError.sql"));

        assert!(!is_hook_file("V1__Create_table.sql"));
        assert!(!is_hook_file("R__Create_view.sql"));
        assert!(!is_hook_file("beforeMigrate.txt"));
        assert!(!is_hook_file("random.sql"));
    }

    #[test]
    fn test_scan_hooks_finds_callback_files() {
        let dir = create_temp_dir("scan");
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("afterMigrate__Refresh_views.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("V1__Create_table.sql"), "CREATE TABLE t(id INT);").unwrap();
        fs::write(dir.join("R__Create_view.sql"), "CREATE VIEW v AS SELECT 1;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 2);

        let before: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .collect();
        let after: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::AfterMigrate)
            .collect();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].script_name, "beforeMigrate.sql");
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].script_name, "afterMigrate__Refresh_views.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_multiple_sorted_alphabetically() {
        let dir = create_temp_dir("multi");
        fs::write(dir.join("beforeMigrate__B_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("beforeMigrate__A_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 0;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 3);
        assert_eq!(hooks[0].script_name, "beforeMigrate.sql");
        assert_eq!(hooks[1].script_name, "beforeMigrate__A_first.sql");
        assert_eq!(hooks[2].script_name, "beforeMigrate__B_second.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_requires_exact_or_double_underscore_name() {
        let dir = create_temp_dir("exact");
        fs::write(dir.join("afterMigrateError.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeEachMigrate__Log.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("beforeMigrate.txt"), "SELECT 3;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeEachMigrate);
        assert_eq!(hooks[0].script_name, "beforeEachMigrate__Log.sql");
        assert_eq!(hooks[0].sql, "SELECT 2;");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_skips_missing_location() {
        let hooks = scan_hooks(&[PathBuf::from("/nonexistent/waypoint_hooks_dir")]).unwrap();
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_load_config_hooks() {
        let dir = create_temp_dir("config");
        let hook_file = dir.join("pre.sql");
        fs::write(&hook_file, "SET work_mem = '256MB';").unwrap();

        let config = HooksConfig {
            before_migrate: vec![hook_file],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].script_name, "pre.sql");
        assert_eq!(hooks[0].sql, "SET work_mem = '256MB';");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_load_config_hooks_missing_file() {
        let config = HooksConfig {
            before_migrate: vec![PathBuf::from("/nonexistent/hook.sql")],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        assert!(load_config_hooks(&config).is_err());
    }
}