1use std::collections::HashMap;
4use std::fmt;
5use std::path::PathBuf;
6
7use tokio_postgres::Client;
8
9use crate::config::{HooksConfig, WaypointConfig};
10use crate::db;
11use crate::error::{Result, WaypointError};
12use crate::placeholder::replace_placeholders;
13
/// The lifecycle phase a hook script is attached to.
///
/// The `Display` form of each variant is its camelCase phase name
/// (e.g. `beforeEachMigrate`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookType {
    /// The `beforeMigrate` phase.
    BeforeMigrate,
    /// The `afterMigrate` phase.
    AfterMigrate,
    /// The `beforeEachMigrate` phase.
    BeforeEachMigrate,
    /// The `afterEachMigrate` phase.
    AfterEachMigrate,
}

impl HookType {
    /// Canonical camelCase name of this phase, as shown by `Display`.
    fn label(&self) -> &'static str {
        match self {
            HookType::BeforeMigrate => "beforeMigrate",
            HookType::AfterMigrate => "afterMigrate",
            HookType::BeforeEachMigrate => "beforeEachMigrate",
            HookType::AfterEachMigrate => "afterEachMigrate",
        }
    }
}

impl fmt::Display for HookType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.label())
    }
}
33
34#[derive(Debug, Clone)]
36pub struct ResolvedHook {
37 pub hook_type: HookType,
38 pub script_name: String,
39 pub sql: String,
40}
41
42type HookPrefixEntry = (&'static str, fn() -> HookType);
44const HOOK_PREFIXES: &[HookPrefixEntry] = &[
45 ("beforeEachMigrate", || HookType::BeforeEachMigrate),
46 ("afterEachMigrate", || HookType::AfterEachMigrate),
47 ("beforeMigrate", || HookType::BeforeMigrate),
48 ("afterMigrate", || HookType::AfterMigrate),
49];
50
51pub fn is_hook_file(filename: &str) -> bool {
53 HOOK_PREFIXES
54 .iter()
55 .any(|(prefix, _)| filename.starts_with(prefix) && filename.ends_with(".sql"))
56}
57
58pub fn scan_hooks(locations: &[PathBuf]) -> Result<Vec<ResolvedHook>> {
68 let mut hooks = Vec::new();
69
70 for location in locations {
71 if !location.exists() {
72 continue;
73 }
74
75 let entries = std::fs::read_dir(location).map_err(|e| {
76 WaypointError::IoError(std::io::Error::new(
77 e.kind(),
78 format!(
79 "Failed to read hook directory '{}': {}",
80 location.display(),
81 e
82 ),
83 ))
84 })?;
85
86 let mut files: Vec<_> = entries
87 .filter_map(|e| e.ok())
88 .filter(|e| e.path().is_file())
89 .collect();
90
91 files.sort_by_key(|e| e.file_name());
93
94 for entry in files {
95 let path = entry.path();
96 let filename = match path.file_name().and_then(|n| n.to_str()) {
97 Some(name) => name.to_string(),
98 None => continue,
99 };
100
101 if !filename.ends_with(".sql") {
102 continue;
103 }
104
105 for (prefix, type_fn) in HOOK_PREFIXES {
107 if filename.starts_with(prefix) {
108 let rest = &filename[prefix.len()..filename.len() - 4]; if rest.is_empty() || rest.starts_with("__") {
111 let sql = std::fs::read_to_string(&path)?;
112 hooks.push(ResolvedHook {
113 hook_type: type_fn(),
114 script_name: filename.clone(),
115 sql,
116 });
117 break;
118 }
119 }
120 }
121 }
122 }
123
124 hooks.sort_by(|a, b| {
126 a.hook_type
127 .to_string()
128 .cmp(&b.hook_type.to_string())
129 .then_with(|| a.script_name.cmp(&b.script_name))
130 });
131
132 Ok(hooks)
133}
134
135pub fn load_config_hooks(config: &HooksConfig) -> Result<Vec<ResolvedHook>> {
137 let mut hooks = Vec::new();
138
139 let sections: &[(HookType, &[PathBuf])] = &[
140 (HookType::BeforeMigrate, &config.before_migrate),
141 (HookType::AfterMigrate, &config.after_migrate),
142 (HookType::BeforeEachMigrate, &config.before_each_migrate),
143 (HookType::AfterEachMigrate, &config.after_each_migrate),
144 ];
145
146 for (hook_type, paths) in sections {
147 for path in *paths {
148 let sql = std::fs::read_to_string(path).map_err(|e| {
149 WaypointError::IoError(std::io::Error::new(
150 e.kind(),
151 format!("Failed to read hook file '{}': {}", path.display(), e),
152 ))
153 })?;
154
155 let script_name = path
156 .file_name()
157 .and_then(|n| n.to_str())
158 .unwrap_or_else(|| path.to_str().unwrap_or("unknown"))
159 .to_string();
160
161 hooks.push(ResolvedHook {
162 hook_type: hook_type.clone(),
163 script_name,
164 sql,
165 });
166 }
167 }
168
169 Ok(hooks)
170}
171
172pub async fn run_hooks(
176 client: &Client,
177 _config: &WaypointConfig,
178 hooks: &[ResolvedHook],
179 phase: &HookType,
180 placeholders: &HashMap<String, String>,
181) -> Result<(usize, i32)> {
182 let mut total_ms = 0;
183 let mut count = 0;
184
185 for hook in hooks.iter().filter(|h| &h.hook_type == phase) {
186 tracing::info!("Running {} hook: {}", phase, hook.script_name);
187
188 let sql = replace_placeholders(&hook.sql, placeholders)?;
189
190 match db::execute_in_transaction(client, &sql).await {
191 Ok(exec_time) => {
192 total_ms += exec_time;
193 count += 1;
194 }
195 Err(e) => {
196 let reason = match &e {
197 WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
198 other => other.to_string(),
199 };
200 return Err(WaypointError::HookFailed {
201 phase: phase.to_string(),
202 script: hook.script_name.clone(),
203 reason,
204 });
205 }
206 }
207 }
208
209 Ok((count, total_ms))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;

    /// Scratch directory that is removed on drop, so it is cleaned up even when an
    /// assertion panics mid-test.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        /// Creates a fresh, empty directory unique to this test name and process.
        /// The PID suffix avoids collisions between concurrent test runs.
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "waypoint_hooks_test_{}_{}",
                name,
                std::process::id()
            ));
            let _ = fs::remove_dir_all(&path);
            fs::create_dir_all(&path).unwrap();
            TempDir { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }

        /// Writes `contents` to `file_name` inside the directory and returns its path.
        fn write(&self, file_name: &str, contents: &str) -> PathBuf {
            let file = self.path.join(file_name);
            fs::write(&file, contents).unwrap();
            file
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Scans a single temp directory, panicking on error.
    fn scan(dir: &TempDir) -> Vec<ResolvedHook> {
        scan_hooks(&[dir.path().to_path_buf()]).unwrap()
    }

    /// Script names of `hooks`, in order.
    fn names(hooks: &[ResolvedHook]) -> Vec<&str> {
        hooks.iter().map(|h| h.script_name.as_str()).collect()
    }

    fn empty_config() -> HooksConfig {
        HooksConfig {
            before_migrate: vec![],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        }
    }

    #[test]
    fn test_hook_type_display_matches_prefixes() {
        for (prefix, make_type) in HOOK_PREFIXES {
            assert_eq!(make_type().to_string(), *prefix);
        }
    }

    #[test]
    fn test_is_hook_file() {
        assert!(is_hook_file("beforeMigrate.sql"));
        assert!(is_hook_file("afterMigrate.sql"));
        assert!(is_hook_file("beforeEachMigrate.sql"));
        assert!(is_hook_file("afterEachMigrate.sql"));
        assert!(is_hook_file("beforeMigrate__Disable_triggers.sql"));
        assert!(is_hook_file("afterMigrate__Refresh_views.sql"));

        assert!(!is_hook_file("V1__Create_table.sql"));
        assert!(!is_hook_file("R__Create_view.sql"));
        assert!(!is_hook_file("beforeMigrate.txt"));
        assert!(!is_hook_file("random.sql"));
    }

    #[test]
    fn test_scan_hooks_finds_callback_files() {
        let dir = TempDir::new("scan");
        dir.write("beforeMigrate.sql", "SELECT 1;");
        dir.write("afterMigrate__Refresh_views.sql", "SELECT 2;");
        dir.write("V1__Create_table.sql", "CREATE TABLE t(id INT);");
        dir.write("R__Create_view.sql", "CREATE VIEW v AS SELECT 1;");

        let hooks = scan(&dir);
        assert_eq!(hooks.len(), 2);

        let before: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .collect();
        let after: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::AfterMigrate)
            .collect();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].script_name, "beforeMigrate.sql");
        assert_eq!(before[0].sql, "SELECT 1;");
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].script_name, "afterMigrate__Refresh_views.sql");
        assert_eq!(after[0].sql, "SELECT 2;");
    }

    #[test]
    fn test_scan_hooks_multiple_sorted_alphabetically() {
        let dir = TempDir::new("multi");
        dir.write("beforeMigrate__B_second.sql", "SELECT 2;");
        dir.write("beforeMigrate__A_first.sql", "SELECT 1;");
        dir.write("beforeMigrate.sql", "SELECT 0;");

        let hooks = scan(&dir);
        assert_eq!(
            names(&hooks),
            vec![
                "beforeMigrate.sql",
                "beforeMigrate__A_first.sql",
                "beforeMigrate__B_second.sql",
            ]
        );
    }

    #[test]
    fn test_scan_hooks_orders_by_hook_type_then_name() {
        let dir = TempDir::new("order");
        dir.write("beforeMigrate.sql", "SELECT 1;");
        dir.write("afterMigrate.sql", "SELECT 2;");
        dir.write("beforeEachMigrate__x.sql", "SELECT 3;");
        dir.write("afterEachMigrate.sql", "SELECT 4;");

        let hooks = scan(&dir);
        assert_eq!(
            names(&hooks),
            vec![
                "afterEachMigrate.sql",
                "afterMigrate.sql",
                "beforeEachMigrate__x.sql",
                "beforeMigrate.sql",
            ]
        );
        assert_eq!(hooks[0].hook_type, HookType::AfterEachMigrate);
        assert_eq!(hooks[2].hook_type, HookType::BeforeEachMigrate);
    }

    #[test]
    fn test_scan_hooks_ignores_near_miss_names() {
        let dir = TempDir::new("near_miss");
        dir.write("beforeMigrate.sql", "SELECT 1;");
        dir.write("beforeMigrateFoo.sql", "SELECT 2;");
        dir.write("beforeMigrate_single.sql", "SELECT 3;");
        dir.write("beforeMigrate.txt", "SELECT 4;");

        let hooks = scan(&dir);
        assert_eq!(names(&hooks), vec!["beforeMigrate.sql"]);
    }

    #[test]
    fn test_scan_hooks_ignores_directories() {
        let dir = TempDir::new("subdir");
        fs::create_dir_all(dir.path().join("beforeMigrate__dir.sql")).unwrap();

        assert!(scan(&dir).is_empty());
    }

    #[test]
    fn test_scan_hooks_skips_missing_location() {
        let dir = TempDir::new("missing");
        let missing = dir.path().join("does_not_exist");

        assert!(scan_hooks(&[missing]).unwrap().is_empty());
    }

    #[test]
    fn test_load_config_hooks() {
        let dir = TempDir::new("config");
        let hook_file = dir.write("pre.sql", "SET work_mem = '256MB';");

        let config = HooksConfig {
            before_migrate: vec![hook_file],
            ..empty_config()
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].script_name, "pre.sql");
        assert_eq!(hooks[0].sql, "SET work_mem = '256MB';");
    }

    #[test]
    fn test_load_config_hooks_preserves_section_and_list_order() {
        let dir = TempDir::new("config_order");
        let b = dir.write("b.sql", "SELECT 'b';");
        let a = dir.write("a.sql", "SELECT 'a';");
        let c = dir.write("c.sql", "SELECT 'c';");

        let config = HooksConfig {
            before_migrate: vec![b, a],
            after_each_migrate: vec![c],
            ..empty_config()
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(names(&hooks), vec!["b.sql", "a.sql", "c.sql"]);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[1].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[2].hook_type, HookType::AfterEachMigrate);
    }

    #[test]
    fn test_load_config_hooks_missing_file() {
        let config = HooksConfig {
            before_migrate: vec![PathBuf::from("/nonexistent/hook.sql")],
            ..empty_config()
        };

        assert!(load_config_hooks(&config).is_err());
    }
}