1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4
5use tokio_postgres::Client;
6
7use crate::config::{HooksConfig, WaypointConfig};
8use crate::db;
9use crate::error::{Result, WaypointError};
10use crate::placeholder::replace_placeholders;
11
/// The kind of hook script, identified by its file-name prefix / display name
/// (`beforeMigrate`, `afterMigrate`, `beforeEachMigrate`, `afterEachMigrate`).
///
/// Derives `Hash` in addition to `Eq` so hook types can be used as keys in
/// `HashMap`/`HashSet` (e.g. to group resolved hooks by phase).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HookType {
    /// Scripts whose file name starts with `beforeMigrate`.
    BeforeMigrate,
    /// Scripts whose file name starts with `afterMigrate`.
    AfterMigrate,
    /// Scripts whose file name starts with `beforeEachMigrate`.
    BeforeEachMigrate,
    /// Scripts whose file name starts with `afterEachMigrate`.
    AfterEachMigrate,
}
20
21impl fmt::Display for HookType {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 match self {
24 HookType::BeforeMigrate => write!(f, "beforeMigrate"),
25 HookType::AfterMigrate => write!(f, "afterMigrate"),
26 HookType::BeforeEachMigrate => write!(f, "beforeEachMigrate"),
27 HookType::AfterEachMigrate => write!(f, "afterEachMigrate"),
28 }
29 }
30}
31
32#[derive(Debug, Clone)]
34pub struct ResolvedHook {
35 pub hook_type: HookType,
36 pub script_name: String,
37 pub sql: String,
38}
39
40type HookPrefixEntry = (&'static str, fn() -> HookType);
42const HOOK_PREFIXES: &[HookPrefixEntry] = &[
43 ("beforeEachMigrate", || HookType::BeforeEachMigrate),
44 ("afterEachMigrate", || HookType::AfterEachMigrate),
45 ("beforeMigrate", || HookType::BeforeMigrate),
46 ("afterMigrate", || HookType::AfterMigrate),
47];
48
49pub fn is_hook_file(filename: &str) -> bool {
51 HOOK_PREFIXES
52 .iter()
53 .any(|(prefix, _)| filename.starts_with(prefix) && filename.ends_with(".sql"))
54}
55
56pub fn scan_hooks(locations: &[PathBuf]) -> Result<Vec<ResolvedHook>> {
66 let mut hooks = Vec::new();
67
68 for location in locations {
69 if !location.exists() {
70 continue;
71 }
72
73 let entries = std::fs::read_dir(location).map_err(|e| {
74 WaypointError::IoError(std::io::Error::new(
75 e.kind(),
76 format!(
77 "Failed to read hook directory '{}': {}",
78 location.display(),
79 e
80 ),
81 ))
82 })?;
83
84 let mut files: Vec<_> = entries
85 .filter_map(|e| e.ok())
86 .filter(|e| e.path().is_file())
87 .collect();
88
89 files.sort_by_key(|e| e.file_name());
91
92 for entry in files {
93 let path = entry.path();
94 let filename = match path.file_name().and_then(|n| n.to_str()) {
95 Some(name) => name.to_string(),
96 None => continue,
97 };
98
99 if !filename.ends_with(".sql") {
100 continue;
101 }
102
103 for (prefix, type_fn) in HOOK_PREFIXES {
105 if filename.starts_with(prefix) {
106 let rest = &filename[prefix.len()..filename.len() - 4]; if rest.is_empty() || rest.starts_with("__") {
109 let sql = std::fs::read_to_string(&path)?;
110 hooks.push(ResolvedHook {
111 hook_type: type_fn(),
112 script_name: filename.clone(),
113 sql,
114 });
115 break;
116 }
117 }
118 }
119 }
120 }
121
122 hooks.sort_by(|a, b| {
124 a.hook_type
125 .to_string()
126 .cmp(&b.hook_type.to_string())
127 .then_with(|| a.script_name.cmp(&b.script_name))
128 });
129
130 Ok(hooks)
131}
132
133pub fn load_config_hooks(config: &HooksConfig) -> Result<Vec<ResolvedHook>> {
135 let mut hooks = Vec::new();
136
137 let sections: &[(HookType, &[PathBuf])] = &[
138 (HookType::BeforeMigrate, &config.before_migrate),
139 (HookType::AfterMigrate, &config.after_migrate),
140 (HookType::BeforeEachMigrate, &config.before_each_migrate),
141 (HookType::AfterEachMigrate, &config.after_each_migrate),
142 ];
143
144 for (hook_type, paths) in sections {
145 for path in *paths {
146 let sql = std::fs::read_to_string(path).map_err(|e| {
147 WaypointError::IoError(std::io::Error::new(
148 e.kind(),
149 format!("Failed to read hook file '{}': {}", path.display(), e),
150 ))
151 })?;
152
153 let script_name = path
154 .file_name()
155 .and_then(|n| n.to_str())
156 .unwrap_or_else(|| path.to_str().unwrap_or("unknown"))
157 .to_string();
158
159 hooks.push(ResolvedHook {
160 hook_type: hook_type.clone(),
161 script_name,
162 sql,
163 });
164 }
165 }
166
167 Ok(hooks)
168}
169
170pub async fn run_hooks(
174 client: &Client,
175 _config: &WaypointConfig,
176 hooks: &[ResolvedHook],
177 phase: &HookType,
178 placeholders: &HashMap<String, String>,
179) -> Result<(usize, i32)> {
180 let mut total_ms = 0;
181 let mut count = 0;
182
183 for hook in hooks.iter().filter(|h| &h.hook_type == phase) {
184 tracing::info!("Running {} hook: {}", phase, hook.script_name);
185
186 let sql = replace_placeholders(&hook.sql, placeholders)?;
187
188 match db::execute_in_transaction(client, &sql).await {
189 Ok(exec_time) => {
190 total_ms += exec_time;
191 count += 1;
192 }
193 Err(e) => {
194 let reason = match &e {
195 WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
196 other => other.to_string(),
197 };
198 return Err(WaypointError::HookFailed {
199 phase: phase.to_string(),
200 script: hook.script_name.clone(),
201 reason,
202 });
203 }
204 }
205 }
206
207 Ok((count, total_ms))
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a fresh, empty scratch directory for a single test.
    ///
    /// The process id is part of the name so that concurrent test processes
    /// sharing one temp dir do not remove or overwrite each other's fixtures.
    fn create_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "waypoint_hooks_test_{}_{}",
            std::process::id(),
            name
        ));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_hook_type_display() {
        assert_eq!(HookType::BeforeMigrate.to_string(), "beforeMigrate");
        assert_eq!(HookType::AfterMigrate.to_string(), "afterMigrate");
        assert_eq!(HookType::BeforeEachMigrate.to_string(), "beforeEachMigrate");
        assert_eq!(HookType::AfterEachMigrate.to_string(), "afterEachMigrate");
    }

    #[test]
    fn test_is_hook_file() {
        assert!(is_hook_file("beforeMigrate.sql"));
        assert!(is_hook_file("afterMigrate.sql"));
        assert!(is_hook_file("beforeEachMigrate.sql"));
        assert!(is_hook_file("afterEachMigrate.sql"));
        assert!(is_hook_file("beforeMigrate__Disable_triggers.sql"));
        assert!(is_hook_file("afterMigrate__Refresh_views.sql"));

        assert!(!is_hook_file("V1__Create_table.sql"));
        assert!(!is_hook_file("R__Create_view.sql"));
        assert!(!is_hook_file("beforeMigrate.txt"));
        assert!(!is_hook_file("random.sql"));
    }

    #[test]
    fn test_scan_hooks_finds_callback_files() {
        let dir = create_temp_dir("scan");
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("afterMigrate__Refresh_views.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("V1__Create_table.sql"), "CREATE TABLE t(id INT);").unwrap();
        fs::write(dir.join("R__Create_view.sql"), "CREATE VIEW v AS SELECT 1;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 2);

        let before: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .collect();
        let after: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::AfterMigrate)
            .collect();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].script_name, "beforeMigrate.sql");
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].script_name, "afterMigrate__Refresh_views.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_multiple_sorted_alphabetically() {
        let dir = create_temp_dir("multi");
        fs::write(dir.join("beforeMigrate__B_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("beforeMigrate__A_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("beforeMigrate.sql"), "SELECT 0;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();

        assert_eq!(hooks.len(), 3);
        assert_eq!(hooks[0].script_name, "beforeMigrate.sql");
        assert_eq!(hooks[1].script_name, "beforeMigrate__A_first.sql");
        assert_eq!(hooks[2].script_name, "beforeMigrate__B_second.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_ignores_names_that_only_share_a_prefix() {
        let dir = create_temp_dir("prefix_only");
        fs::write(dir.join("beforeMigrateExtra.sql"), "SELECT 1;").unwrap();
        fs::write(dir.join("afterMigrate_single_underscore.sql"), "SELECT 2;").unwrap();
        fs::write(dir.join("afterMigrate.txt"), "SELECT 3;").unwrap();

        let hooks = scan_hooks(&[dir.clone()]).unwrap();
        assert!(hooks.is_empty());

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_scan_hooks_skips_missing_locations() {
        let missing = std::env::temp_dir().join(format!(
            "waypoint_hooks_test_{}_does_not_exist",
            std::process::id()
        ));
        let _ = fs::remove_dir_all(&missing);

        let hooks = scan_hooks(&[missing]).unwrap();
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_scan_hooks_merges_locations() {
        let first = create_temp_dir("loc_a");
        let second = create_temp_dir("loc_b");
        fs::write(first.join("beforeMigrate__B.sql"), "SELECT 2;").unwrap();
        fs::write(second.join("beforeMigrate__A.sql"), "SELECT 1;").unwrap();

        let hooks = scan_hooks(&[first.clone(), second.clone()]).unwrap();
        let names: Vec<&str> = hooks.iter().map(|h| h.script_name.as_str()).collect();
        assert_eq!(names, vec!["beforeMigrate__A.sql", "beforeMigrate__B.sql"]);

        let _ = fs::remove_dir_all(&first);
        let _ = fs::remove_dir_all(&second);
    }

    #[test]
    fn test_load_config_hooks() {
        let dir = create_temp_dir("config");
        let hook_file = dir.join("pre.sql");
        fs::write(&hook_file, "SET work_mem = '256MB';").unwrap();

        let config = HooksConfig {
            before_migrate: vec![hook_file],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].sql, "SET work_mem = '256MB';");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_load_config_hooks_preserves_section_order() {
        let dir = create_temp_dir("config_order");
        let after = dir.join("after.sql");
        let each = dir.join("each.sql");
        fs::write(&after, "SELECT 'after';").unwrap();
        fs::write(&each, "SELECT 'each';").unwrap();

        let config = HooksConfig {
            before_migrate: vec![],
            after_migrate: vec![after],
            before_each_migrate: vec![],
            after_each_migrate: vec![each],
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 2);
        assert_eq!(hooks[0].hook_type, HookType::AfterMigrate);
        assert_eq!(hooks[0].script_name, "after.sql");
        assert_eq!(hooks[1].hook_type, HookType::AfterEachMigrate);
        assert_eq!(hooks[1].script_name, "each.sql");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_load_config_hooks_missing_file() {
        let config = HooksConfig {
            before_migrate: vec![PathBuf::from("/nonexistent/hook.sql")],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        assert!(load_config_hooks(&config).is_err());
    }
}