Skip to main content

waypoint_core/
hooks.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4
5use tokio_postgres::Client;
6
7use crate::config::{HooksConfig, WaypointConfig};
8use crate::db;
9use crate::error::{Result, WaypointError};
10use crate::placeholder::replace_placeholders;
11
/// Lifecycle phase a hook script is attached to.
///
/// The `Display` form of each variant is the camelCase name that hook files
/// use as their prefix (e.g. `beforeMigrate`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HookType {
    /// Displayed as `beforeMigrate`.
    BeforeMigrate,
    /// Displayed as `afterMigrate`.
    AfterMigrate,
    /// Displayed as `beforeEachMigrate`.
    BeforeEachMigrate,
    /// Displayed as `afterEachMigrate`.
    AfterEachMigrate,
}

impl fmt::Display for HookType {
    /// Writes the camelCase phase name (no padding or width handling).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::BeforeMigrate => "beforeMigrate",
            Self::AfterMigrate => "afterMigrate",
            Self::BeforeEachMigrate => "beforeEachMigrate",
            Self::AfterEachMigrate => "afterEachMigrate",
        };
        f.write_str(label)
    }
}
31
32/// A hook SQL script discovered on disk or specified in config.
33#[derive(Debug, Clone)]
34pub struct ResolvedHook {
35    pub hook_type: HookType,
36    pub script_name: String,
37    pub sql: String,
38}
39
40/// File prefixes that indicate hook callback files (Flyway-compatible).
41type HookPrefixEntry = (&'static str, fn() -> HookType);
42const HOOK_PREFIXES: &[HookPrefixEntry] = &[
43    ("beforeEachMigrate", || HookType::BeforeEachMigrate),
44    ("afterEachMigrate", || HookType::AfterEachMigrate),
45    ("beforeMigrate", || HookType::BeforeMigrate),
46    ("afterMigrate", || HookType::AfterMigrate),
47];
48
49/// Check if a filename is a hook callback file (not a migration).
50pub fn is_hook_file(filename: &str) -> bool {
51    HOOK_PREFIXES
52        .iter()
53        .any(|(prefix, _)| filename.starts_with(prefix) && filename.ends_with(".sql"))
54}
55
56/// Scan migration locations for SQL callback hook files.
57///
58/// Recognizes:
59///   - `beforeMigrate.sql` / `beforeMigrate__*.sql`
60///   - `afterMigrate.sql` / `afterMigrate__*.sql`
61///   - `beforeEachMigrate.sql` / `beforeEachMigrate__*.sql`
62///   - `afterEachMigrate.sql` / `afterEachMigrate__*.sql`
63///
64/// Multiple files per hook type are sorted alphabetically.
65pub fn scan_hooks(locations: &[PathBuf]) -> Result<Vec<ResolvedHook>> {
66    let mut hooks = Vec::new();
67
68    for location in locations {
69        if !location.exists() {
70            continue;
71        }
72
73        let entries = std::fs::read_dir(location).map_err(|e| {
74            WaypointError::IoError(std::io::Error::new(
75                e.kind(),
76                format!(
77                    "Failed to read hook directory '{}': {}",
78                    location.display(),
79                    e
80                ),
81            ))
82        })?;
83
84        let mut files: Vec<_> = entries
85            .filter_map(|e| e.ok())
86            .filter(|e| e.path().is_file())
87            .collect();
88
89        // Sort alphabetically for deterministic ordering
90        files.sort_by_key(|e| e.file_name());
91
92        for entry in files {
93            let path = entry.path();
94            let filename = match path.file_name().and_then(|n| n.to_str()) {
95                Some(name) => name.to_string(),
96                None => continue,
97            };
98
99            if !filename.ends_with(".sql") {
100                continue;
101            }
102
103            // Check each hook prefix
104            for (prefix, type_fn) in HOOK_PREFIXES {
105                if filename.starts_with(prefix) {
106                    // Must be exactly `prefix.sql` or `prefix__*.sql`
107                    let rest = &filename[prefix.len()..filename.len() - 4]; // strip prefix and .sql
108                    if rest.is_empty() || rest.starts_with("__") {
109                        let sql = std::fs::read_to_string(&path)?;
110                        hooks.push(ResolvedHook {
111                            hook_type: type_fn(),
112                            script_name: filename.clone(),
113                            sql,
114                        });
115                        break;
116                    }
117                }
118            }
119        }
120    }
121
122    // Sort within each hook type alphabetically by script name
123    hooks.sort_by(|a, b| {
124        a.hook_type
125            .to_string()
126            .cmp(&b.hook_type.to_string())
127            .then_with(|| a.script_name.cmp(&b.script_name))
128    });
129
130    Ok(hooks)
131}
132
133/// Load hook SQL files specified in the TOML `[hooks]` config section.
134pub fn load_config_hooks(config: &HooksConfig) -> Result<Vec<ResolvedHook>> {
135    let mut hooks = Vec::new();
136
137    let sections: &[(HookType, &[PathBuf])] = &[
138        (HookType::BeforeMigrate, &config.before_migrate),
139        (HookType::AfterMigrate, &config.after_migrate),
140        (HookType::BeforeEachMigrate, &config.before_each_migrate),
141        (HookType::AfterEachMigrate, &config.after_each_migrate),
142    ];
143
144    for (hook_type, paths) in sections {
145        for path in *paths {
146            let sql = std::fs::read_to_string(path).map_err(|e| {
147                WaypointError::IoError(std::io::Error::new(
148                    e.kind(),
149                    format!("Failed to read hook file '{}': {}", path.display(), e),
150                ))
151            })?;
152
153            let script_name = path
154                .file_name()
155                .and_then(|n| n.to_str())
156                .unwrap_or_else(|| path.to_str().unwrap_or("unknown"))
157                .to_string();
158
159            hooks.push(ResolvedHook {
160                hook_type: hook_type.clone(),
161                script_name,
162                sql,
163            });
164        }
165    }
166
167    Ok(hooks)
168}
169
170/// Run all hooks of a given type.
171///
172/// Returns total execution time in milliseconds.
173pub async fn run_hooks(
174    client: &Client,
175    _config: &WaypointConfig,
176    hooks: &[ResolvedHook],
177    phase: &HookType,
178    placeholders: &HashMap<String, String>,
179) -> Result<(usize, i32)> {
180    let mut total_ms = 0;
181    let mut count = 0;
182
183    for hook in hooks.iter().filter(|h| &h.hook_type == phase) {
184        tracing::info!("Running {} hook: {}", phase, hook.script_name);
185
186        let sql = replace_placeholders(&hook.sql, placeholders)?;
187
188        match db::execute_in_transaction(client, &sql).await {
189            Ok(exec_time) => {
190                total_ms += exec_time;
191                count += 1;
192            }
193            Err(e) => {
194                let reason = match &e {
195                    WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
196                    other => other.to_string(),
197                };
198                return Err(WaypointError::HookFailed {
199                    phase: phase.to_string(),
200                    script: hook.script_name.clone(),
201                    reason,
202                });
203            }
204        }
205    }
206
207    Ok((count, total_ms))
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create a fresh, empty scratch directory for one test.
    ///
    /// The process id is part of the name so that concurrent runs of this
    /// suite (e.g. two `cargo test` invocations) cannot delete each other's
    /// fixtures.
    fn create_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "waypoint_hooks_test_{}_{}",
            name,
            std::process::id()
        ));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_is_hook_file() {
        assert!(is_hook_file("beforeMigrate.sql"));
        assert!(is_hook_file("afterMigrate.sql"));
        assert!(is_hook_file("beforeEachMigrate.sql"));
        assert!(is_hook_file("afterEachMigrate.sql"));
        assert!(is_hook_file("beforeMigrate__Disable_triggers.sql"));
        assert!(is_hook_file("afterMigrate__Refresh_views.sql"));

        assert!(!is_hook_file("V1__Create_table.sql"));
        assert!(!is_hook_file("R__Create_view.sql"));
        assert!(!is_hook_file("beforeMigrate.txt"));
        assert!(!is_hook_file("random.sql"));
    }

    #[test]
    fn test_scan_hooks_finds_callback_files() {
        let dir = create_temp_dir("scan");
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("afterMigrate__Refresh_views.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("V1__Create_table.sql"), "CREATE TABLE t(id INT);").unwrap();
        fs::write(dir.join("R__Create_view.sql"), "CREATE VIEW v AS SELECT 1;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 2);

        let before: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .collect();
        let after: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::AfterMigrate)
            .collect();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].script_name, "beforeMigrate.sql");
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].script_name, "afterMigrate__Refresh_views.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_multiple_sorted_alphabetically() {
        let dir = create_temp_dir("multi");
        fs::write(dir.join("beforeMigrate__B_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("beforeMigrate__A_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 0;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 3);
        assert_eq!(hooks[0].script_name, "beforeMigrate.sql");
        assert_eq!(hooks[1].script_name, "beforeMigrate__A_first.sql");
        assert_eq!(hooks[2].script_name, "beforeMigrate__B_second.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_ignores_near_miss_names() {
        let dir = create_temp_dir("near_miss");
        fs::write(dir.join("beforeMigrate__Ok.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeMigrateFoo.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("afterMigrate_single.sql"), "SELECT 3;").unwrap();
        fs::write(dir.join("afterMigrate.txt"), "SELECT 4;").unwrap();
        // A directory with a hook-like name must not be treated as a hook.
        fs::create_dir_all(dir.join("afterEachMigrate.sql")).unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].script_name, "beforeMigrate__Ok.sql");
        assert_eq!(hooks[0].sql, "SELECT 1;");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_skips_missing_location() {
        let dir = create_temp_dir("missing");
        let missing = dir.join("does_not_exist");
        fs::write(dir.join("afterMigrate.sql"), "SELECT 1;").unwrap();

        let hooks = scan_hooks(&[missing, dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::AfterMigrate);
        assert_eq!(hooks[0].script_name, "afterMigrate.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_merges_locations() {
        let first = create_temp_dir("merge_a");
        let second = create_temp_dir("merge_b");
        fs::write(first.join("beforeMigrate__B_second.sql"), "SELECT 2;").unwrap();
        fs::write(second.join("beforeMigrate__A_first.sql"), "SELECT 1;").unwrap();
        fs::write(second.join("afterMigrate.sql"), "SELECT 3;").unwrap();

        let hooks = scan_hooks(&[first.clone(), second.clone()]).unwrap();

        assert_eq!(hooks.len(), 3);
        let before: Vec<&str> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .map(|h| h.script_name.as_str())
            .collect();
        assert_eq!(
            before,
            ["beforeMigrate__A_first.sql", "beforeMigrate__B_second.sql"]
        );

        let _ = fs::remove_dir_all(&first);
        let _ = fs::remove_dir_all(&second);
    }

    #[test]
    fn test_load_config_hooks() {
        let dir = create_temp_dir("config");
        let hook_file = dir.join("pre.sql");
        fs::write(&hook_file, "SET work_mem = '256MB';").unwrap();

        let config = HooksConfig {
            before_migrate: vec![hook_file],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].sql, "SET work_mem = '256MB';");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_load_config_hooks_missing_file() {
        let config = HooksConfig {
            before_migrate: vec![PathBuf::from("/nonexistent/hook.sql")],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        assert!(load_config_hooks(&config).is_err());
    }
}