1use std::collections::HashMap;
4use std::fmt;
5use std::path::PathBuf;
6
7use tokio_postgres::Client;
8
9use crate::config::HooksConfig;
10use crate::db;
11use crate::error::{Result, WaypointError};
12use crate::placeholder::replace_placeholders;
13
/// Phase of a migration run that a hook script is attached to.
///
/// Each variant corresponds to a hook filename prefix (see `HOOK_PREFIXES`) and to one list in
/// `HooksConfig`; its `Display` form is that camelCase prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookType {
    /// `beforeMigrate*.sql` files or `before_migrate` config entries
    /// (by name, intended to run before the migration batch — call site not in this module).
    BeforeMigrate,
    /// `afterMigrate*.sql` files or `after_migrate` config entries
    /// (by name, intended to run after the migration batch — call site not in this module).
    AfterMigrate,
    /// `beforeEachMigrate*.sql` files or `before_each_migrate` config entries
    /// (by name, intended to run before every individual migration).
    BeforeEachMigrate,
    /// `afterEachMigrate*.sql` files or `after_each_migrate` config entries
    /// (by name, intended to run after every individual migration).
    AfterEachMigrate,
}
26
27impl fmt::Display for HookType {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 match self {
30 HookType::BeforeMigrate => write!(f, "beforeMigrate"),
31 HookType::AfterMigrate => write!(f, "afterMigrate"),
32 HookType::BeforeEachMigrate => write!(f, "beforeEachMigrate"),
33 HookType::AfterEachMigrate => write!(f, "afterEachMigrate"),
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct ResolvedHook {
41 pub hook_type: HookType,
43 pub script_name: String,
45 pub sql: String,
47}
48
/// A hook filename prefix paired with a constructor for the [`HookType`] it denotes.
type HookPrefixEntry = (&'static str, fn() -> HookType);

/// Recognized hook filename prefixes.
///
/// `is_hook_file` and `scan_hooks` both match filenames against these entries; `scan_hooks`
/// walks them in this order and stops at the first accepted match. No prefix here is itself a
/// prefix of another entry, so the order does not change which entry a filename matches.
const HOOK_PREFIXES: &[HookPrefixEntry] = &[
    ("beforeEachMigrate", || HookType::BeforeEachMigrate),
    ("afterEachMigrate", || HookType::AfterEachMigrate),
    ("beforeMigrate", || HookType::BeforeMigrate),
    ("afterMigrate", || HookType::AfterMigrate),
];
57
58pub fn is_hook_file(filename: &str) -> bool {
60 HOOK_PREFIXES
61 .iter()
62 .any(|(prefix, _)| filename.starts_with(prefix) && filename.ends_with(".sql"))
63}
64
65pub fn scan_hooks(locations: &[PathBuf]) -> Result<Vec<ResolvedHook>> {
75 let mut hooks = Vec::new();
76
77 for location in locations {
78 if !location.exists() {
79 continue;
80 }
81
82 let entries = std::fs::read_dir(location).map_err(|e| {
83 WaypointError::IoError(std::io::Error::new(
84 e.kind(),
85 format!(
86 "Failed to read hook directory '{}': {}",
87 location.display(),
88 e
89 ),
90 ))
91 })?;
92
93 let mut files: Vec<_> = entries
94 .filter_map(|e| e.ok())
95 .filter(|e| e.path().is_file())
96 .collect();
97
98 files.sort_by_key(|e| e.file_name());
100
101 for entry in files {
102 let path = entry.path();
103 let filename = match path.file_name().and_then(|n| n.to_str()) {
104 Some(name) => name.to_string(),
105 None => continue,
106 };
107
108 if !filename.ends_with(".sql") {
109 continue;
110 }
111
112 for (prefix, type_fn) in HOOK_PREFIXES {
114 if filename.starts_with(prefix) {
115 let rest = &filename[prefix.len()..filename.len() - 4]; if rest.is_empty() || rest.starts_with("__") {
118 let sql = std::fs::read_to_string(&path)?;
119 hooks.push(ResolvedHook {
120 hook_type: type_fn(),
121 script_name: filename.clone(),
122 sql,
123 });
124 break;
125 }
126 }
127 }
128 }
129 }
130
131 hooks.sort_by(|a, b| {
133 a.hook_type
134 .to_string()
135 .cmp(&b.hook_type.to_string())
136 .then_with(|| a.script_name.cmp(&b.script_name))
137 });
138
139 Ok(hooks)
140}
141
142pub fn load_config_hooks(config: &HooksConfig) -> Result<Vec<ResolvedHook>> {
144 let mut hooks = Vec::new();
145
146 let sections: &[(HookType, &[PathBuf])] = &[
147 (HookType::BeforeMigrate, &config.before_migrate),
148 (HookType::AfterMigrate, &config.after_migrate),
149 (HookType::BeforeEachMigrate, &config.before_each_migrate),
150 (HookType::AfterEachMigrate, &config.after_each_migrate),
151 ];
152
153 for (hook_type, paths) in sections {
154 for path in *paths {
155 let sql = std::fs::read_to_string(path).map_err(|e| {
156 WaypointError::IoError(std::io::Error::new(
157 e.kind(),
158 format!("Failed to read hook file '{}': {}", path.display(), e),
159 ))
160 })?;
161
162 let script_name = path
163 .file_name()
164 .and_then(|n| n.to_str())
165 .unwrap_or_else(|| path.to_str().unwrap_or("unknown"))
166 .to_string();
167
168 hooks.push(ResolvedHook {
169 hook_type: hook_type.clone(),
170 script_name,
171 sql,
172 });
173 }
174 }
175
176 Ok(hooks)
177}
178
179pub async fn run_hooks(
183 client: &Client,
184 hooks: &[ResolvedHook],
185 phase: &HookType,
186 placeholders: &HashMap<String, String>,
187) -> Result<(usize, i32)> {
188 let mut total_ms = 0;
189 let mut count = 0;
190
191 for hook in hooks.iter().filter(|h| &h.hook_type == phase) {
192 log::info!("Running {} hook: {}", phase, hook.script_name);
193
194 let sql = replace_placeholders(&hook.sql, placeholders)?;
195
196 match db::execute_in_transaction(client, &sql).await {
197 Ok(exec_time) => {
198 total_ms += exec_time;
199 count += 1;
200 }
201 Err(e) => {
202 let reason = match &e {
203 WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
204 other => other.to_string(),
205 };
206 return Err(WaypointError::HookFailed {
207 phase: phase.to_string(),
208 script: hook.script_name.clone(),
209 reason,
210 });
211 }
212 }
213 }
214
215 Ok((count, total_ms))
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Scratch directory that is removed when dropped, so a failing assertion does not leave
    /// stale files behind for the next run.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Creates an empty directory unique to `name` and the current process. The process id
        /// keeps concurrent test runs on the same machine from sharing a directory.
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "waypoint_hooks_test_{}_{}",
                name,
                std::process::id()
            ));
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).unwrap();
            TempDir(dir)
        }

        /// Path of the scratch directory.
        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover directory only wastes disk space.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_is_hook_file() {
        assert!(is_hook_file("beforeMigrate.sql"));
        assert!(is_hook_file("afterMigrate.sql"));
        assert!(is_hook_file("beforeEachMigrate.sql"));
        assert!(is_hook_file("afterEachMigrate.sql"));
        assert!(is_hook_file("beforeMigrate__Disable_triggers.sql"));
        assert!(is_hook_file("afterMigrate__Refresh_views.sql"));

        assert!(!is_hook_file("V1__Create_table.sql"));
        assert!(!is_hook_file("R__Create_view.sql"));
        assert!(!is_hook_file("beforeMigrate.txt"));
        assert!(!is_hook_file("random.sql"));
    }

    #[test]
    fn test_scan_hooks_finds_callback_files() {
        let dir = TempDir::new("scan");
        fs::write(dir.path().join("beforeMigrate.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("afterMigrate__Refresh_views.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("V1__Create_table.sql"), "CREATE TABLE t(id INT);").unwrap();
        fs::write(dir.path().join("R__Create_view.sql"), "CREATE VIEW v AS SELECT 1;").unwrap();

        let hooks = scan_hooks(&[dir.path().to_path_buf()]).unwrap();

        assert_eq!(hooks.len(), 2);

        let before: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::BeforeMigrate)
            .collect();
        let after: Vec<_> = hooks
            .iter()
            .filter(|h| h.hook_type == HookType::AfterMigrate)
            .collect();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].script_name, "beforeMigrate.sql");
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].script_name, "afterMigrate__Refresh_views.sql");
    }

    #[test]
    fn test_scan_hooks_multiple_sorted_alphabetically() {
        let dir = TempDir::new("multi");
        fs::write(dir.path().join("beforeMigrate__B_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("beforeMigrate__A_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("beforeMigrate.sql"), "SELECT 0;").unwrap();

        let hooks = scan_hooks(&[dir.path().to_path_buf()]).unwrap();

        assert_eq!(hooks.len(), 3);
        assert_eq!(hooks[0].script_name, "beforeMigrate.sql");
        assert_eq!(hooks[1].script_name, "beforeMigrate__A_first.sql");
        assert_eq!(hooks[2].script_name, "beforeMigrate__B_second.sql");
    }

    #[test]
    fn test_load_config_hooks() {
        let dir = TempDir::new("config");
        let hook_file = dir.path().join("pre.sql");
        fs::write(&hook_file, "SET work_mem = '256MB';").unwrap();

        let config = HooksConfig {
            before_migrate: vec![hook_file],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        let hooks = load_config_hooks(&config).unwrap();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].hook_type, HookType::BeforeMigrate);
        assert_eq!(hooks[0].sql, "SET work_mem = '256MB';");
    }

    #[test]
    fn test_load_config_hooks_missing_file() {
        let config = HooksConfig {
            before_migrate: vec![PathBuf::from("/nonexistent/hook.sql")],
            after_migrate: vec![],
            before_each_migrate: vec![],
            after_each_migrate: vec![],
        };

        assert!(load_config_hooks(&config).is_err());
    }
}