// waypoint_core/commands/migrate.rs

1use std::collections::HashMap;
2
3use serde::Serialize;
4use tokio_postgres::Client;
5
6use crate::config::WaypointConfig;
7use crate::db;
8use crate::error::{Result, WaypointError};
9use crate::history;
10use crate::hooks::{self, HookType, ResolvedHook};
11use crate::migration::{scan_migrations, MigrationVersion, ResolvedMigration};
12use crate::placeholder::{build_placeholders, replace_placeholders};
13
14/// Report returned after a migrate operation.
15#[derive(Debug, Serialize)]
16pub struct MigrateReport {
17    pub migrations_applied: usize,
18    pub total_time_ms: i32,
19    pub details: Vec<MigrateDetail>,
20    pub hooks_executed: usize,
21    pub hooks_time_ms: i32,
22}
23
24#[derive(Debug, Serialize)]
25pub struct MigrateDetail {
26    pub version: Option<String>,
27    pub description: String,
28    pub script: String,
29    pub execution_time_ms: i32,
30}
31
32/// Execute the migrate command.
33pub async fn execute(
34    client: &Client,
35    config: &WaypointConfig,
36    target_version: Option<&str>,
37) -> Result<MigrateReport> {
38    let table = &config.migrations.table;
39
40    // Acquire advisory lock
41    db::acquire_advisory_lock(client, table).await?;
42
43    let result = run_migrate(client, config, target_version).await;
44
45    // Always release the advisory lock
46    if let Err(e) = db::release_advisory_lock(client, table).await {
47        tracing::warn!(error = %e, "Failed to release advisory lock");
48    }
49
50    match &result {
51        Ok(report) => {
52            tracing::info!(
53                migrations_applied = report.migrations_applied,
54                total_time_ms = report.total_time_ms,
55                hooks_executed = report.hooks_executed,
56                "Migrate completed"
57            );
58        }
59        Err(e) => {
60            tracing::error!(error = %e, "Migrate failed");
61        }
62    }
63
64    result
65}
66
67async fn run_migrate(
68    client: &Client,
69    config: &WaypointConfig,
70    target_version: Option<&str>,
71) -> Result<MigrateReport> {
72    let schema = &config.migrations.schema;
73    let table = &config.migrations.table;
74
75    // Create history table if not exists
76    history::create_history_table(client, schema, table).await?;
77
78    // Validate on migrate if enabled
79    if config.migrations.validate_on_migrate {
80        if let Err(e) = super::validate::execute(client, config).await {
81            // Only fail on actual validation errors, not if there's nothing to validate
82            match &e {
83                WaypointError::ValidationFailed(_) => return Err(e),
84                _ => {
85                    tracing::debug!("Validation skipped: {}", e);
86                }
87            }
88        }
89    }
90
91    // Scan migration files
92    let resolved = scan_migrations(&config.migrations.locations)?;
93
94    // Scan and load hooks
95    let mut all_hooks: Vec<ResolvedHook> = hooks::scan_hooks(&config.migrations.locations)?;
96    let config_hooks = hooks::load_config_hooks(&config.hooks)?;
97    all_hooks.extend(config_hooks);
98
99    // Get applied migrations
100    let applied = history::get_applied_migrations(client, schema, table).await?;
101
102    // Get database user info for placeholders
103    let db_user = db::get_current_user(client)
104        .await
105        .unwrap_or_else(|_| "unknown".to_string());
106    let db_name = db::get_current_database(client)
107        .await
108        .unwrap_or_else(|_| "unknown".to_string());
109    let installed_by = config
110        .migrations
111        .installed_by
112        .as_deref()
113        .unwrap_or(&db_user);
114
115    // Parse target version if provided
116    let target = target_version.map(MigrationVersion::parse).transpose()?;
117
118    // Find the baseline version if any
119    let baseline_version = applied
120        .iter()
121        .find(|a| a.migration_type == "BASELINE")
122        .and_then(|a| a.version.as_ref())
123        .map(|v| MigrationVersion::parse(v))
124        .transpose()?;
125
126    // Find highest applied versioned migration (version presence, not type string,
127    // for Flyway compatibility)
128    let highest_applied = applied
129        .iter()
130        .filter(|a| a.success && a.version.is_some())
131        .filter_map(|a| a.version.as_ref())
132        .filter_map(|v| MigrationVersion::parse(v).ok())
133        .max();
134
135    // Build set of applied versions and scripts for quick lookup
136    let applied_versions: HashMap<String, &crate::history::AppliedMigration> = applied
137        .iter()
138        .filter(|a| a.success)
139        .filter_map(|a| a.version.as_ref().map(|v| (v.clone(), a)))
140        .collect();
141
142    let applied_scripts: HashMap<String, &crate::history::AppliedMigration> = applied
143        .iter()
144        .filter(|a| a.success && a.version.is_none())
145        .map(|a| (a.script.clone(), a))
146        .collect();
147
148    let mut report = MigrateReport {
149        migrations_applied: 0,
150        total_time_ms: 0,
151        details: Vec::new(),
152        hooks_executed: 0,
153        hooks_time_ms: 0,
154    };
155
156    // ── beforeMigrate hooks ──
157    let before_placeholders = build_placeholders(
158        &config.placeholders,
159        schema,
160        &db_user,
161        &db_name,
162        "beforeMigrate",
163    );
164    let (count, ms) = hooks::run_hooks(
165        client,
166        config,
167        &all_hooks,
168        &HookType::BeforeMigrate,
169        &before_placeholders,
170    )
171    .await?;
172    report.hooks_executed += count;
173    report.hooks_time_ms += ms;
174
175    // ── Apply versioned migrations ──
176    let versioned: Vec<&ResolvedMigration> = resolved.iter().filter(|m| m.is_versioned()).collect();
177
178    for migration in &versioned {
179        let version = migration.version().unwrap();
180
181        // Skip if already applied successfully
182        if applied_versions.contains_key(&version.raw) {
183            continue;
184        }
185
186        // Skip if below baseline
187        if let Some(ref bv) = baseline_version {
188            if version <= bv {
189                tracing::debug!("Skipping {} (below baseline)", migration.script);
190                continue;
191            }
192        }
193
194        // Check target version
195        if let Some(ref tv) = target {
196            if version > tv {
197                tracing::debug!("Skipping {} (above target {})", migration.script, tv);
198                break;
199            }
200        }
201
202        // Check out-of-order
203        if !config.migrations.out_of_order {
204            if let Some(ref highest) = highest_applied {
205                if version < highest {
206                    return Err(WaypointError::OutOfOrder {
207                        version: version.raw.clone(),
208                        highest: highest.raw.clone(),
209                    });
210                }
211            }
212        }
213
214        // beforeEachMigrate hooks
215        let each_placeholders = build_placeholders(
216            &config.placeholders,
217            schema,
218            &db_user,
219            &db_name,
220            &migration.script,
221        );
222        let (count, ms) = hooks::run_hooks(
223            client,
224            config,
225            &all_hooks,
226            &HookType::BeforeEachMigrate,
227            &each_placeholders,
228        )
229        .await?;
230        report.hooks_executed += count;
231        report.hooks_time_ms += ms;
232
233        // Apply migration
234        let exec_time = apply_migration(
235            client,
236            config,
237            migration,
238            schema,
239            table,
240            installed_by,
241            &db_user,
242            &db_name,
243        )
244        .await?;
245
246        // afterEachMigrate hooks
247        let (count, ms) = hooks::run_hooks(
248            client,
249            config,
250            &all_hooks,
251            &HookType::AfterEachMigrate,
252            &each_placeholders,
253        )
254        .await?;
255        report.hooks_executed += count;
256        report.hooks_time_ms += ms;
257
258        report.migrations_applied += 1;
259        report.total_time_ms += exec_time;
260        report.details.push(MigrateDetail {
261            version: Some(version.raw.clone()),
262            description: migration.description.clone(),
263            script: migration.script.clone(),
264            execution_time_ms: exec_time,
265        });
266    }
267
268    // ── Apply repeatable migrations ──
269    let repeatables: Vec<&ResolvedMigration> =
270        resolved.iter().filter(|m| !m.is_versioned()).collect();
271
272    for migration in &repeatables {
273        // Check if already applied with same checksum
274        if let Some(applied_entry) = applied_scripts.get(&migration.script) {
275            if applied_entry.checksum == Some(migration.checksum) {
276                continue; // Unchanged, skip
277            }
278            // Checksum differs — re-apply (outdated)
279            tracing::info!(migration = %migration.script, "Re-applying changed repeatable migration");
280        }
281
282        // beforeEachMigrate hooks
283        let each_placeholders = build_placeholders(
284            &config.placeholders,
285            schema,
286            &db_user,
287            &db_name,
288            &migration.script,
289        );
290        let (count, ms) = hooks::run_hooks(
291            client,
292            config,
293            &all_hooks,
294            &HookType::BeforeEachMigrate,
295            &each_placeholders,
296        )
297        .await?;
298        report.hooks_executed += count;
299        report.hooks_time_ms += ms;
300
301        let exec_time = apply_migration(
302            client,
303            config,
304            migration,
305            schema,
306            table,
307            installed_by,
308            &db_user,
309            &db_name,
310        )
311        .await?;
312
313        // afterEachMigrate hooks
314        let (count, ms) = hooks::run_hooks(
315            client,
316            config,
317            &all_hooks,
318            &HookType::AfterEachMigrate,
319            &each_placeholders,
320        )
321        .await?;
322        report.hooks_executed += count;
323        report.hooks_time_ms += ms;
324
325        report.migrations_applied += 1;
326        report.total_time_ms += exec_time;
327        report.details.push(MigrateDetail {
328            version: None,
329            description: migration.description.clone(),
330            script: migration.script.clone(),
331            execution_time_ms: exec_time,
332        });
333    }
334
335    // ── afterMigrate hooks ──
336    let after_placeholders = build_placeholders(
337        &config.placeholders,
338        schema,
339        &db_user,
340        &db_name,
341        "afterMigrate",
342    );
343    let (count, ms) = hooks::run_hooks(
344        client,
345        config,
346        &all_hooks,
347        &HookType::AfterMigrate,
348        &after_placeholders,
349    )
350    .await?;
351    report.hooks_executed += count;
352    report.hooks_time_ms += ms;
353
354    Ok(report)
355}
356
357#[allow(clippy::too_many_arguments)]
358async fn apply_migration(
359    client: &Client,
360    config: &WaypointConfig,
361    migration: &ResolvedMigration,
362    schema: &str,
363    table: &str,
364    installed_by: &str,
365    db_user: &str,
366    db_name: &str,
367) -> Result<i32> {
368    tracing::info!(migration = %migration.script, schema = %schema, "Applying migration");
369
370    // Build placeholders
371    let placeholders = build_placeholders(
372        &config.placeholders,
373        schema,
374        db_user,
375        db_name,
376        &migration.script,
377    );
378
379    // Replace placeholders in SQL
380    let sql = replace_placeholders(&migration.sql, &placeholders)?;
381
382    let version_str = migration.version().map(|v| v.raw.as_str());
383    let type_str = migration.migration_type().to_string();
384
385    // Execute in transaction
386    match db::execute_in_transaction(client, &sql).await {
387        Ok(exec_time) => {
388            // Record success (rank is assigned atomically in the INSERT)
389            history::insert_applied_migration(
390                client,
391                schema,
392                table,
393                version_str,
394                &migration.description,
395                &type_str,
396                &migration.script,
397                Some(migration.checksum),
398                installed_by,
399                exec_time,
400                true,
401            )
402            .await?;
403
404            Ok(exec_time)
405        }
406        Err(e) => {
407            // Record failure — we try to insert the failure record, but don't fail if that also fails
408            if let Err(record_err) = history::insert_applied_migration(
409                client,
410                schema,
411                table,
412                version_str,
413                &migration.description,
414                &type_str,
415                &migration.script,
416                Some(migration.checksum),
417                installed_by,
418                0,
419                false,
420            )
421            .await
422            {
423                tracing::warn!(script = %migration.script, error = %record_err, "Failed to record migration failure in history table");
424            }
425
426            // Extract detailed error message
427            let reason = match &e {
428                WaypointError::DatabaseError(db_err) => crate::error::format_db_error(db_err),
429                other => other.to_string(),
430            };
431            tracing::error!(script = %migration.script, reason = %reason, "Migration failed");
432            Err(WaypointError::MigrationFailed {
433                script: migration.script.clone(),
434                reason,
435            })
436        }
437    }
438}