Skip to main content

waypoint_core/commands/
migrate.rs

1//! Apply pending migrations to the database.
2
3use std::collections::{HashMap, HashSet};
4
5use serde::Serialize;
6use tokio_postgres::Client;
7
8use crate::config::WaypointConfig;
9use crate::db;
10use crate::directive::MigrationDirectives;
11use crate::error::{Result, WaypointError};
12use crate::history;
13use crate::hooks::{self, HookType, ResolvedHook};
14use crate::migration::{scan_migrations, MigrationVersion, ResolvedMigration};
15use crate::placeholder::{build_placeholders, replace_placeholders};
16
17/// Check if a migration should run in the current environment.
18///
19/// Returns true if:
20/// - The migration has no env directives (runs everywhere)
21/// - No environment is configured (runs everything)
22/// - The migration's env list includes the current environment
23fn should_run_in_environment(directives: &MigrationDirectives, current_env: Option<&str>) -> bool {
24    // No env directives = runs everywhere
25    if directives.env.is_empty() {
26        return true;
27    }
28    // No environment configured = runs everything
29    let env = match current_env {
30        Some(e) => e,
31        None => return true,
32    };
33    // Check if current env matches any directive
34    directives.env.iter().any(|e| e.eq_ignore_ascii_case(env))
35}
36
37/// Report returned after a migrate operation.
38#[derive(Debug, Serialize)]
39pub struct MigrateReport {
40    /// Number of migrations that were applied in this run.
41    pub migrations_applied: usize,
42    /// Total execution time of all migrations in milliseconds.
43    pub total_time_ms: i32,
44    /// Per-migration details for each applied migration.
45    pub details: Vec<MigrateDetail>,
46    /// Number of lifecycle hooks that were executed.
47    pub hooks_executed: usize,
48    /// Total execution time of all hooks in milliseconds.
49    pub hooks_time_ms: i32,
50}
51
52/// Details of a single applied migration within a migrate run.
53#[derive(Debug, Serialize)]
54pub struct MigrateDetail {
55    /// Version string, or None for repeatable migrations.
56    pub version: Option<String>,
57    /// Human-readable description from the migration filename.
58    pub description: String,
59    /// Filename of the migration script.
60    pub script: String,
61    /// Execution time of this migration in milliseconds.
62    pub execution_time_ms: i32,
63}
64
65/// Result of evaluating require guard preconditions for a single migration.
66enum GuardAction {
67    /// All preconditions passed; proceed with the migration.
68    Continue,
69    /// A precondition failed with on_require_fail=Skip; skip this migration.
70    Skip,
71    /// A precondition failed fatally; abort with the given error.
72    Error(WaypointError),
73}
74
75/// Common state prepared by `prepare_migrate()` for both run modes.
76struct MigrateSetup<'a> {
77    /// All resolved migration files on disk.
78    resolved: Vec<ResolvedMigration>,
79    /// All hooks (from disk + config).
80    all_hooks: Vec<ResolvedHook>,
81    /// Current database user.
82    db_user: String,
83    /// Current database name.
84    db_name: String,
85    /// Who to record as the installer.
86    installed_by: String,
87    /// Parsed target version, if specified.
88    target: Option<MigrationVersion>,
89    /// Baseline version from history, if any.
90    baseline_version: Option<MigrationVersion>,
91    /// Set of effectively-applied version strings (respects undo).
92    effective_versions: HashSet<String>,
93    /// Highest effectively-applied version.
94    highest_applied: Option<MigrationVersion>,
95    /// Map of repeatable script name -> applied checksum (for checksum comparison).
96    applied_scripts: HashMap<String, Option<i32>>,
97    /// Current environment from config.
98    current_env: Option<&'a str>,
99}
100
101/// Perform all shared setup: history table creation, validation, preflight,
102/// file scanning, hooks loading, version computation.
103async fn prepare_migrate<'a>(
104    client: &Client,
105    config: &'a WaypointConfig,
106    target_version: Option<&str>,
107) -> Result<MigrateSetup<'a>> {
108    let schema = &config.migrations.schema;
109    let table = &config.migrations.table;
110
111    // Create history table if not exists
112    history::create_history_table(client, schema, table).await?;
113
114    // Validate on migrate if enabled
115    if config.migrations.validate_on_migrate {
116        if let Err(e) = super::validate::execute(client, config).await {
117            // Only fail on actual validation errors, not if there's nothing to validate
118            match &e {
119                WaypointError::ValidationFailed(_) => return Err(e),
120                _ => {
121                    log::debug!("Validation skipped: {}", e);
122                }
123            }
124        }
125    }
126
127    // Run preflight checks if enabled
128    if config.preflight.enabled {
129        let preflight_report = crate::preflight::run_preflight(client, &config.preflight).await?;
130        if !preflight_report.passed {
131            let failed_checks: Vec<String> = preflight_report
132                .checks
133                .iter()
134                .filter(|c| c.status == crate::preflight::CheckStatus::Fail)
135                .map(|c| format!("{}: {}", c.name, c.detail))
136                .collect();
137            return Err(WaypointError::PreflightFailed {
138                checks: failed_checks.join("; "),
139            });
140        }
141    }
142
143    // Scan migration files
144    let resolved = scan_migrations(&config.migrations.locations)?;
145
146    // Scan and load hooks
147    let mut all_hooks: Vec<ResolvedHook> = hooks::scan_hooks(&config.migrations.locations)?;
148    let config_hooks = hooks::load_config_hooks(&config.hooks)?;
149    all_hooks.extend(config_hooks);
150
151    // Get applied migrations
152    let applied = history::get_applied_migrations(client, schema, table).await?;
153
154    // Get database user info for placeholders
155    let db_user = db::get_current_user(client)
156        .await
157        .unwrap_or_else(|_| "unknown".to_string());
158    let db_name = db::get_current_database(client)
159        .await
160        .unwrap_or_else(|_| "unknown".to_string());
161    let installed_by = config
162        .migrations
163        .installed_by
164        .as_deref()
165        .unwrap_or(&db_user)
166        .to_string();
167
168    // Parse target version if provided
169    let target = target_version.map(MigrationVersion::parse).transpose()?;
170
171    // Find the baseline version if any
172    let baseline_version = applied
173        .iter()
174        .find(|a| a.migration_type == "BASELINE")
175        .and_then(|a| a.version.as_ref())
176        .map(|v| MigrationVersion::parse(v))
177        .transpose()?;
178
179    // Compute effective applied versions (respects undo state)
180    let effective_versions = history::effective_applied_versions(&applied);
181
182    // Find highest effectively-applied versioned migration
183    let highest_applied = effective_versions
184        .iter()
185        .filter_map(|v| MigrationVersion::parse(v).ok())
186        .max();
187
188    let applied_scripts: HashMap<String, Option<i32>> = applied
189        .iter()
190        .filter(|a| a.success && a.version.is_none())
191        .map(|a| (a.script.clone(), a.checksum))
192        .collect();
193
194    let current_env = config.migrations.environment.as_deref();
195
196    Ok(MigrateSetup {
197        resolved,
198        all_hooks,
199        db_user,
200        db_name,
201        installed_by,
202        target,
203        baseline_version,
204        effective_versions,
205        highest_applied,
206        applied_scripts,
207        current_env,
208    })
209}
210
211/// Filter resolved migrations down to pending versioned ones, applying
212/// baseline/target/out-of-order checks.
213fn filter_pending_versioned<'a>(
214    versioned: &[&'a ResolvedMigration],
215    setup: &MigrateSetup<'_>,
216    config: &WaypointConfig,
217) -> Result<Vec<&'a ResolvedMigration>> {
218    let mut pending = Vec::new();
219    for migration in versioned {
220        let version = migration.version().unwrap();
221
222        // Skip if already effectively applied (respects undo state)
223        if setup.effective_versions.contains(&version.raw) {
224            continue;
225        }
226
227        // Skip if below baseline
228        if let Some(ref bv) = setup.baseline_version {
229            if version <= bv {
230                log::debug!("Skipping {} (below baseline)", migration.script);
231                continue;
232            }
233        }
234
235        // Check target version
236        if let Some(ref tv) = setup.target {
237            if version > tv {
238                log::debug!("Skipping {} (above target {})", migration.script, tv);
239                break;
240            }
241        }
242
243        // Check out-of-order
244        if !config.migrations.out_of_order {
245            if let Some(ref highest) = setup.highest_applied {
246                if version < highest {
247                    return Err(WaypointError::OutOfOrder {
248                        version: version.raw.clone(),
249                        highest: highest.raw.clone(),
250                    });
251                }
252            }
253        }
254
255        pending.push(*migration);
256    }
257    Ok(pending)
258}
259
260/// Filter resolved migrations down to pending repeatable ones (checksum changed or new).
261fn filter_pending_repeatables<'a>(
262    repeatables: &[&'a ResolvedMigration],
263    setup: &MigrateSetup<'_>,
264) -> Vec<&'a ResolvedMigration> {
265    let mut pending = Vec::new();
266    for migration in repeatables {
267        if let Some(&applied_checksum) = setup.applied_scripts.get(&migration.script) {
268            if applied_checksum == Some(migration.checksum) {
269                continue;
270            }
271        }
272        pending.push(*migration);
273    }
274    pending
275}
276
277/// Evaluate all `-- waypoint:require` guard preconditions for a migration.
278///
279/// Returns `GuardAction::Continue` if all guards pass, `GuardAction::Skip` if
280/// the migration should be skipped (when `on_require_fail = Skip`), or
281/// `GuardAction::Error` if a fatal guard failure occurs.
282async fn evaluate_require_guards(
283    client: &Client,
284    schema: &str,
285    migration: &ResolvedMigration,
286    config: &WaypointConfig,
287) -> Result<GuardAction> {
288    if migration.directives.require.is_empty() {
289        return Ok(GuardAction::Continue);
290    }
291
292    for expr_str in &migration.directives.require {
293        match crate::guard::parse(expr_str) {
294            Ok(expr) => {
295                match crate::guard::evaluate(client, schema, &expr).await {
296                    Ok(true) => {} // Precondition met
297                    Ok(false) => {
298                        match config.guards.on_require_fail {
299                            crate::guard::OnRequireFail::Skip => {
300                                log::info!(
301                                    "Guard require failed, skipping migration; script={}, expr={}",
302                                    migration.script,
303                                    expr_str
304                                );
305                                return Ok(GuardAction::Skip);
306                            }
307                            crate::guard::OnRequireFail::Warn => {
308                                log::warn!(
309                                    "Guard require failed (continuing); script={}, expr={}",
310                                    migration.script,
311                                    expr_str
312                                );
313                                // Continue with the migration despite guard failure
314                            }
315                            crate::guard::OnRequireFail::Error => {
316                                return Ok(GuardAction::Error(WaypointError::GuardFailed {
317                                    kind: "require".to_string(),
318                                    script: migration.script.clone(),
319                                    expression: expr_str.clone(),
320                                }));
321                            }
322                        }
323                    }
324                    Err(e) => {
325                        log::warn!(
326                            "Guard evaluation error; script={}, expr={}, error={}",
327                            migration.script,
328                            expr_str,
329                            e
330                        );
331                        return Ok(GuardAction::Error(WaypointError::GuardFailed {
332                            kind: "require".to_string(),
333                            script: migration.script.clone(),
334                            expression: format!("{} (evaluation error: {})", expr_str, e),
335                        }));
336                    }
337                }
338            }
339            Err(e) => {
340                return Ok(GuardAction::Error(WaypointError::GuardFailed {
341                    kind: "require".to_string(),
342                    script: migration.script.clone(),
343                    expression: format!("{} (parse error: {})", expr_str, e),
344                }));
345            }
346        }
347    }
348    Ok(GuardAction::Continue)
349}
350
351/// Evaluate all `-- waypoint:ensure` guard postconditions for a migration.
352///
353/// Returns `Ok(())` if all postconditions pass. Returns an error if any
354/// postcondition fails or cannot be evaluated.
355async fn evaluate_ensure_guards(
356    client: &Client,
357    schema: &str,
358    migration: &ResolvedMigration,
359) -> Result<()> {
360    for expr_str in &migration.directives.ensure {
361        match crate::guard::parse(expr_str) {
362            Ok(expr) => {
363                match crate::guard::evaluate(client, schema, &expr).await {
364                    Ok(true) => {} // Postcondition met
365                    Ok(false) => {
366                        return Err(WaypointError::GuardFailed {
367                            kind: "ensure".to_string(),
368                            script: migration.script.clone(),
369                            expression: expr_str.clone(),
370                        });
371                    }
372                    Err(e) => {
373                        return Err(WaypointError::GuardFailed {
374                            kind: "ensure".to_string(),
375                            script: migration.script.clone(),
376                            expression: format!("{} (evaluation error: {})", expr_str, e),
377                        });
378                    }
379                }
380            }
381            Err(e) => {
382                return Err(WaypointError::GuardFailed {
383                    kind: "ensure".to_string(),
384                    script: migration.script.clone(),
385                    expression: format!("{} (parse error: {})", expr_str, e),
386                });
387            }
388        }
389    }
390    Ok(())
391}
392
393/// Execute the migrate command.
394pub async fn execute(
395    client: &Client,
396    config: &WaypointConfig,
397    target_version: Option<&str>,
398) -> Result<MigrateReport> {
399    execute_with_options(client, config, target_version, false).await
400}
401
402/// Execute the migrate command with additional options.
403pub async fn execute_with_options(
404    client: &Client,
405    config: &WaypointConfig,
406    target_version: Option<&str>,
407    force: bool,
408) -> Result<MigrateReport> {
409    let table = &config.migrations.table;
410
411    // Acquire advisory lock
412    db::acquire_advisory_lock(client, table).await?;
413
414    let result = if config.migrations.batch_transaction {
415        run_batch_migrate(client, config, target_version, force).await
416    } else {
417        run_migrate(client, config, target_version, force).await
418    };
419
420    // Always release the advisory lock
421    if let Err(e) = db::release_advisory_lock(client, table).await {
422        log::error!("Failed to release advisory lock: {}", e);
423    }
424
425    match &result {
426        Ok(report) => {
427            log::info!(
428                "Migrate completed; migrations_applied={}, total_time_ms={}, hooks_executed={}",
429                report.migrations_applied,
430                report.total_time_ms,
431                report.hooks_executed
432            );
433        }
434        Err(e) => {
435            log::error!("Migrate failed: {}", e);
436        }
437    }
438
439    result
440}
441
442async fn run_migrate(
443    client: &Client,
444    config: &WaypointConfig,
445    target_version: Option<&str>,
446    force_override: bool,
447) -> Result<MigrateReport> {
448    let schema = &config.migrations.schema;
449    let table = &config.migrations.table;
450
451    let setup = prepare_migrate(client, config, target_version).await?;
452
453    let mut report = MigrateReport {
454        migrations_applied: 0,
455        total_time_ms: 0,
456        details: Vec::new(),
457        hooks_executed: 0,
458        hooks_time_ms: 0,
459    };
460
461    // ── beforeMigrate hooks ──
462    let before_placeholders = build_placeholders(
463        &config.placeholders,
464        schema,
465        &setup.db_user,
466        &setup.db_name,
467        "beforeMigrate",
468    );
469    let (count, ms) = hooks::run_hooks(
470        client,
471        &setup.all_hooks,
472        &HookType::BeforeMigrate,
473        &before_placeholders,
474    )
475    .await?;
476    report.hooks_executed += count;
477    report.hooks_time_ms += ms;
478
479    // ── Apply versioned migrations ──
480    let versioned: Vec<&ResolvedMigration> = setup
481        .resolved
482        .iter()
483        .filter(|m| m.is_versioned())
484        .filter(|m| should_run_in_environment(&m.directives, setup.current_env))
485        .collect();
486
487    let pending_versioned = filter_pending_versioned(&versioned, &setup, config)?;
488
489    for migration in &pending_versioned {
490        let version = migration.version().unwrap();
491
492        // beforeEachMigrate hooks
493        let each_placeholders = build_placeholders(
494            &config.placeholders,
495            schema,
496            &setup.db_user,
497            &setup.db_name,
498            &migration.script,
499        );
500        let (count, ms) = hooks::run_hooks(
501            client,
502            &setup.all_hooks,
503            &HookType::BeforeEachMigrate,
504            &each_placeholders,
505        )
506        .await?;
507        report.hooks_executed += count;
508        report.hooks_time_ms += ms;
509
510        // ── Safety analysis (before apply) ──
511        if config.safety.enabled {
512            let safety_report = crate::safety::analyze_migration(
513                client,
514                schema,
515                &migration.sql,
516                &migration.script,
517                &config.safety,
518            )
519            .await?;
520            if safety_report.overall_verdict == crate::safety::SafetyVerdict::Danger
521                && config.safety.block_on_danger
522                && !migration.directives.safety_override
523                && !force_override
524            {
525                return Err(WaypointError::MigrationBlocked {
526                    script: migration.script.clone(),
527                    reason: safety_report.suggestions.join("; "),
528                });
529            }
530        }
531
532        // ── Guard preconditions (require) ──
533        match evaluate_require_guards(client, schema, migration, config).await? {
534            GuardAction::Continue => {}
535            GuardAction::Skip => continue,
536            GuardAction::Error(e) => return Err(e),
537        }
538
539        // ── Capture before-snapshot for auto-reversal ──
540        let before_snapshot = if config.reversals.enabled && migration.is_versioned() {
541            Some(crate::reversal::capture_before(client, schema).await?)
542        } else {
543            None
544        };
545
546        // Apply migration (hold transaction open if we need to evaluate ensure guards)
547        let has_ensure_guards = !migration.directives.ensure.is_empty();
548        let exec_time = apply_migration(
549            client,
550            config,
551            migration,
552            schema,
553            table,
554            &setup.installed_by,
555            &setup.db_user,
556            &setup.db_name,
557            has_ensure_guards,
558        )
559        .await?;
560
561        // ── Guard postconditions (ensure) — evaluated inside the open transaction ──
562        if has_ensure_guards {
563            if let Err(guard_err) = evaluate_ensure_guards(client, schema, migration).await {
564                if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
565                    log::error!(
566                        "Failed to rollback after ensure guard failure: {}",
567                        rollback_err
568                    );
569                }
570                return Err(guard_err);
571            }
572            // All ensure guards passed — commit the transaction
573            client.batch_execute("COMMIT").await?;
574        }
575
576        // ── Auto-reversal generation (after successful apply) ──
577        if let Some(ref before) = before_snapshot {
578            if let Some(ver) = migration.version() {
579                match crate::reversal::generate_reversal(
580                    client,
581                    schema,
582                    before,
583                    config.reversals.warn_data_loss,
584                )
585                .await
586                {
587                    Ok(result) => {
588                        if let Some(ref reversal_sql) = result.reversal_sql {
589                            if let Err(e) = crate::reversal::store_reversal(
590                                client,
591                                schema,
592                                table,
593                                &ver.raw,
594                                reversal_sql,
595                            )
596                            .await
597                            {
598                                log::warn!(
599                                    "Failed to store reversal SQL; version={}, error={}",
600                                    ver.raw,
601                                    e
602                                );
603                            }
604                        }
605                        for warning in &result.warnings {
606                            log::warn!("Reversal warning for {}: {}", migration.script, warning);
607                        }
608                    }
609                    Err(e) => {
610                        log::warn!(
611                            "Failed to generate reversal; script={}, error={}",
612                            migration.script,
613                            e
614                        );
615                    }
616                }
617            }
618        }
619
620        // afterEachMigrate hooks
621        let (count, ms) = hooks::run_hooks(
622            client,
623            &setup.all_hooks,
624            &HookType::AfterEachMigrate,
625            &each_placeholders,
626        )
627        .await?;
628        report.hooks_executed += count;
629        report.hooks_time_ms += ms;
630
631        report.migrations_applied += 1;
632        report.total_time_ms += exec_time;
633        report.details.push(MigrateDetail {
634            version: Some(version.raw.clone()),
635            description: migration.description.clone(),
636            script: migration.script.clone(),
637            execution_time_ms: exec_time,
638        });
639    }
640
641    // ── Apply repeatable migrations ──
642    let repeatables: Vec<&ResolvedMigration> = setup
643        .resolved
644        .iter()
645        .filter(|m| !m.is_versioned() && !m.is_undo())
646        .filter(|m| should_run_in_environment(&m.directives, setup.current_env))
647        .collect();
648
649    for migration in &repeatables {
650        // Check if already applied with same checksum
651        if let Some(&applied_checksum) = setup.applied_scripts.get(&migration.script) {
652            if applied_checksum == Some(migration.checksum) {
653                continue; // Unchanged, skip
654            }
655            // Checksum differs — re-apply (outdated)
656            log::info!(
657                "Re-applying changed repeatable migration; migration={}",
658                migration.script
659            );
660        }
661
662        // beforeEachMigrate hooks
663        let each_placeholders = build_placeholders(
664            &config.placeholders,
665            schema,
666            &setup.db_user,
667            &setup.db_name,
668            &migration.script,
669        );
670        let (count, ms) = hooks::run_hooks(
671            client,
672            &setup.all_hooks,
673            &HookType::BeforeEachMigrate,
674            &each_placeholders,
675        )
676        .await?;
677        report.hooks_executed += count;
678        report.hooks_time_ms += ms;
679
680        let exec_time = apply_migration(
681            client,
682            config,
683            migration,
684            schema,
685            table,
686            &setup.installed_by,
687            &setup.db_user,
688            &setup.db_name,
689            false,
690        )
691        .await?;
692
693        // afterEachMigrate hooks
694        let (count, ms) = hooks::run_hooks(
695            client,
696            &setup.all_hooks,
697            &HookType::AfterEachMigrate,
698            &each_placeholders,
699        )
700        .await?;
701        report.hooks_executed += count;
702        report.hooks_time_ms += ms;
703
704        report.migrations_applied += 1;
705        report.total_time_ms += exec_time;
706        report.details.push(MigrateDetail {
707            version: None,
708            description: migration.description.clone(),
709            script: migration.script.clone(),
710            execution_time_ms: exec_time,
711        });
712    }
713
714    // ── afterMigrate hooks ──
715    let after_placeholders = build_placeholders(
716        &config.placeholders,
717        schema,
718        &setup.db_user,
719        &setup.db_name,
720        "afterMigrate",
721    );
722    let (count, ms) = hooks::run_hooks(
723        client,
724        &setup.all_hooks,
725        &HookType::AfterMigrate,
726        &after_placeholders,
727    )
728    .await?;
729    report.hooks_executed += count;
730    report.hooks_time_ms += ms;
731
732    Ok(report)
733}
734
735/// Pre-compiled regexes for batch-compatibility checks.
736mod batch_regexes {
737    use std::sync::LazyLock;
738    pub static DROP_INDEX_CONCURRENT: LazyLock<regex_lite::Regex> =
739        LazyLock::new(|| regex_lite::Regex::new(r"(?i)DROP\s+INDEX\s+CONCURRENTLY").unwrap());
740    pub static CREATE_DATABASE: LazyLock<regex_lite::Regex> =
741        LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bCREATE DATABASE\b").unwrap());
742    pub static DROP_DATABASE: LazyLock<regex_lite::Regex> =
743        LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bDROP DATABASE\b").unwrap());
744    pub static VACUUM: LazyLock<regex_lite::Regex> =
745        LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bVACUUM\b").unwrap());
746    pub static CLUSTER: LazyLock<regex_lite::Regex> =
747        LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bCLUSTER\b").unwrap());
748    pub static REINDEX_CONCURRENT: LazyLock<regex_lite::Regex> =
749        LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bREINDEX\b.*\bCONCURRENTLY\b").unwrap());
750}
751
752/// Check that a migration's SQL does not contain statements that cannot run inside a transaction.
753///
754/// Returns an error if CONCURRENTLY, CREATE DATABASE, DROP DATABASE, VACUUM, CLUSTER,
755/// or REINDEX CONCURRENTLY are found. Uses pre-compiled static regexes for efficiency.
756fn validate_batch_compatible(script: &str, sql: &str) -> Result<()> {
757    let upper = sql.to_uppercase();
758
759    // Check CONCURRENTLY: verify via DDL parser first, then regex for DROP INDEX CONCURRENTLY
760    if upper.contains("CONCURRENTLY") {
761        let ops = crate::sql_parser::extract_ddl_operations(sql);
762        for op in &ops {
763            if let crate::sql_parser::DdlOperation::CreateIndex {
764                is_concurrent: true,
765                ..
766            } = op
767            {
768                return Err(WaypointError::NonTransactionalStatement {
769                    script: script.to_string(),
770                    statement: op.to_string(),
771                });
772            }
773        }
774        if batch_regexes::DROP_INDEX_CONCURRENT.is_match(sql) {
775            return Err(WaypointError::NonTransactionalStatement {
776                script: script.to_string(),
777                statement: "DROP INDEX CONCURRENTLY".to_string(),
778            });
779        }
780    }
781
782    // Check CREATE/DROP DATABASE
783    if upper.contains("CREATE DATABASE") && batch_regexes::CREATE_DATABASE.is_match(sql) {
784        return Err(WaypointError::NonTransactionalStatement {
785            script: script.to_string(),
786            statement: "CREATE DATABASE".to_string(),
787        });
788    }
789    if upper.contains("DROP DATABASE") && batch_regexes::DROP_DATABASE.is_match(sql) {
790        return Err(WaypointError::NonTransactionalStatement {
791            script: script.to_string(),
792            statement: "DROP DATABASE".to_string(),
793        });
794    }
795
796    // Check VACUUM, CLUSTER, REINDEX CONCURRENTLY
797    let checks: &[(&regex_lite::Regex, &str, &str)] = &[
798        (&batch_regexes::VACUUM, "VACUUM", "VACUUM"),
799        (&batch_regexes::CLUSTER, "CLUSTER", "CLUSTER"),
800        (
801            &batch_regexes::REINDEX_CONCURRENT,
802            "REINDEX",
803            "REINDEX CONCURRENTLY",
804        ),
805    ];
806    for &(re, fast_check, desc) in checks {
807        if upper.contains(fast_check) && re.is_match(sql) {
808            return Err(WaypointError::NonTransactionalStatement {
809                script: script.to_string(),
810                statement: desc.to_string(),
811            });
812        }
813    }
814
815    Ok(())
816}
817
818/// Run all pending migrations in a single transaction (all-or-nothing batch mode).
819async fn run_batch_migrate(
820    client: &Client,
821    config: &WaypointConfig,
822    target_version: Option<&str>,
823    force_override: bool,
824) -> Result<MigrateReport> {
825    let schema = &config.migrations.schema;
826    let table = &config.migrations.table;
827
828    let setup = prepare_migrate(client, config, target_version).await?;
829
830    let current_env = setup.current_env;
831
832    // Build the list of pending versioned migrations
833    let versioned: Vec<&ResolvedMigration> = setup
834        .resolved
835        .iter()
836        .filter(|m| m.is_versioned())
837        .filter(|m| should_run_in_environment(&m.directives, current_env))
838        .collect();
839
840    let mut pending_versioned = filter_pending_versioned(&versioned, &setup, config)?;
841
842    // Build list of pending repeatable migrations
843    let repeatables: Vec<&ResolvedMigration> = setup
844        .resolved
845        .iter()
846        .filter(|m| !m.is_versioned() && !m.is_undo())
847        .filter(|m| should_run_in_environment(&m.directives, current_env))
848        .collect();
849    let pending_repeatables = filter_pending_repeatables(&repeatables, &setup);
850
851    // Pre-validate: check all pending migrations are batch-compatible
852    let placeholders_map = build_placeholders(
853        &config.placeholders,
854        schema,
855        &setup.db_user,
856        &setup.db_name,
857        "batch_validate",
858    );
859    for migration in pending_versioned.iter().chain(pending_repeatables.iter()) {
860        let sql = replace_placeholders(&migration.sql, &placeholders_map)?;
861        validate_batch_compatible(&migration.script, &sql)?;
862    }
863
864    // Safety analysis (before batch transaction)
865    if config.safety.enabled {
866        for migration in &pending_versioned {
867            let safety_report = crate::safety::analyze_migration(
868                client,
869                schema,
870                &migration.sql,
871                &migration.script,
872                &config.safety,
873            )
874            .await?;
875            if safety_report.overall_verdict == crate::safety::SafetyVerdict::Danger
876                && config.safety.block_on_danger
877                && !migration.directives.safety_override
878                && !force_override
879            {
880                return Err(WaypointError::MigrationBlocked {
881                    script: migration.script.clone(),
882                    reason: safety_report.suggestions.join("; "),
883                });
884            }
885        }
886    }
887
888    // Guard preconditions (before batch transaction)
889    let mut skipped_scripts: HashSet<&str> = HashSet::new();
890    for migration in &pending_versioned {
891        match evaluate_require_guards(client, schema, migration, config).await? {
892            GuardAction::Continue => {}
893            GuardAction::Skip => {
894                skipped_scripts.insert(&migration.script);
895            }
896            GuardAction::Error(e) => return Err(e),
897        }
898    }
899    // Remove skipped migrations
900    pending_versioned.retain(|m| !skipped_scripts.contains(m.script.as_str()));
901
902    let mut report = MigrateReport {
903        migrations_applied: 0,
904        total_time_ms: 0,
905        details: Vec::new(),
906        hooks_executed: 0,
907        hooks_time_ms: 0,
908    };
909
910    // Run beforeMigrate hooks (outside the batch transaction)
911    let before_placeholders = build_placeholders(
912        &config.placeholders,
913        schema,
914        &setup.db_user,
915        &setup.db_name,
916        "beforeMigrate",
917    );
918    let (count, ms) = hooks::run_hooks(
919        client,
920        &setup.all_hooks,
921        &HookType::BeforeMigrate,
922        &before_placeholders,
923    )
924    .await?;
925    report.hooks_executed += count;
926    report.hooks_time_ms += ms;
927
928    // Nothing to apply?
929    if pending_versioned.is_empty() && pending_repeatables.is_empty() {
930        // Run afterMigrate hooks
931        let after_placeholders = build_placeholders(
932            &config.placeholders,
933            schema,
934            &setup.db_user,
935            &setup.db_name,
936            "afterMigrate",
937        );
938        let (count, ms) = hooks::run_hooks(
939            client,
940            &setup.all_hooks,
941            &HookType::AfterMigrate,
942            &after_placeholders,
943        )
944        .await?;
945        report.hooks_executed += count;
946        report.hooks_time_ms += ms;
947        return Ok(report);
948    }
949
950    // Capture before-snapshot for auto-reversal (before batch transaction)
951    let before_snapshot = if config.reversals.enabled {
952        match crate::reversal::capture_before(client, schema).await {
953            Ok(snap) => Some(snap),
954            Err(e) => {
955                log::warn!(
956                    "Failed to capture before-snapshot for batch reversal: {}",
957                    e
958                );
959                None
960            }
961        }
962    } else {
963        None
964    };
965
966    // ── BEGIN batch transaction ──
967    let batch_start = std::time::Instant::now();
968    client.batch_execute("BEGIN").await?;
969
970    let installed_by = &setup.installed_by;
971    let batch_result = async {
972        // Apply versioned migrations inside the transaction
973        for migration in &pending_versioned {
974            let version = migration.version().unwrap();
975            let each_placeholders = build_placeholders(
976                &config.placeholders,
977                schema,
978                &setup.db_user,
979                &setup.db_name,
980                &migration.script,
981            );
982
983            // beforeEachMigrate hooks (inside transaction)
984            let (count, ms) = hooks::run_hooks(
985                client,
986                &setup.all_hooks,
987                &HookType::BeforeEachMigrate,
988                &each_placeholders,
989            )
990            .await?;
991            report.hooks_executed += count;
992            report.hooks_time_ms += ms;
993
994            let sql = replace_placeholders(&migration.sql, &each_placeholders)?;
995            let start = std::time::Instant::now();
996            client
997                .batch_execute(&sql)
998                .await
999                .map_err(|e| WaypointError::MigrationFailed {
1000                    script: migration.script.clone(),
1001                    reason: crate::error::format_db_error(&e),
1002                })?;
1003            let exec_time = start.elapsed().as_millis() as i32;
1004
1005            // Record history inside the same transaction
1006            let version_str = Some(version.raw.as_str());
1007            let type_str = migration.migration_type().to_string();
1008            history::insert_applied_migration(
1009                client,
1010                schema,
1011                table,
1012                version_str,
1013                &migration.description,
1014                &type_str,
1015                &migration.script,
1016                Some(migration.checksum),
1017                installed_by,
1018                exec_time,
1019                true,
1020            )
1021            .await?;
1022
1023            // afterEachMigrate hooks (inside transaction)
1024            let (count, ms) = hooks::run_hooks(
1025                client,
1026                &setup.all_hooks,
1027                &HookType::AfterEachMigrate,
1028                &each_placeholders,
1029            )
1030            .await?;
1031            report.hooks_executed += count;
1032            report.hooks_time_ms += ms;
1033
1034            report.migrations_applied += 1;
1035            report.total_time_ms += exec_time;
1036            report.details.push(MigrateDetail {
1037                version: Some(version.raw.clone()),
1038                description: migration.description.clone(),
1039                script: migration.script.clone(),
1040                execution_time_ms: exec_time,
1041            });
1042        }
1043
1044        // Apply repeatable migrations inside the transaction
1045        for migration in &pending_repeatables {
1046            let each_placeholders = build_placeholders(
1047                &config.placeholders,
1048                schema,
1049                &setup.db_user,
1050                &setup.db_name,
1051                &migration.script,
1052            );
1053
1054            let (count, ms) = hooks::run_hooks(
1055                client,
1056                &setup.all_hooks,
1057                &HookType::BeforeEachMigrate,
1058                &each_placeholders,
1059            )
1060            .await?;
1061            report.hooks_executed += count;
1062            report.hooks_time_ms += ms;
1063
1064            let sql = replace_placeholders(&migration.sql, &each_placeholders)?;
1065            let start = std::time::Instant::now();
1066            client
1067                .batch_execute(&sql)
1068                .await
1069                .map_err(|e| WaypointError::MigrationFailed {
1070                    script: migration.script.clone(),
1071                    reason: crate::error::format_db_error(&e),
1072                })?;
1073            let exec_time = start.elapsed().as_millis() as i32;
1074
1075            let type_str = migration.migration_type().to_string();
1076            history::insert_applied_migration(
1077                client,
1078                schema,
1079                table,
1080                None,
1081                &migration.description,
1082                &type_str,
1083                &migration.script,
1084                Some(migration.checksum),
1085                installed_by,
1086                exec_time,
1087                true,
1088            )
1089            .await?;
1090
1091            let (count, ms) = hooks::run_hooks(
1092                client,
1093                &setup.all_hooks,
1094                &HookType::AfterEachMigrate,
1095                &each_placeholders,
1096            )
1097            .await?;
1098            report.hooks_executed += count;
1099            report.hooks_time_ms += ms;
1100
1101            report.migrations_applied += 1;
1102            report.total_time_ms += exec_time;
1103            report.details.push(MigrateDetail {
1104                version: None,
1105                description: migration.description.clone(),
1106                script: migration.script.clone(),
1107                execution_time_ms: exec_time,
1108            });
1109        }
1110
1111        Ok::<(), WaypointError>(())
1112    }
1113    .await;
1114
1115    match batch_result {
1116        Ok(()) => {
1117            client.batch_execute("COMMIT").await?;
1118            report.total_time_ms = batch_start.elapsed().as_millis() as i32;
1119
1120            // Generate and store reversals for each versioned migration in the batch
1121            if let Some(ref before) = before_snapshot {
1122                for migration in &pending_versioned {
1123                    if let Some(ver) = migration.version() {
1124                        match crate::reversal::generate_reversal(
1125                            client,
1126                            schema,
1127                            before,
1128                            config.reversals.warn_data_loss,
1129                        )
1130                        .await
1131                        {
1132                            Ok(result) => {
1133                                if let Some(ref reversal_sql) = result.reversal_sql {
1134                                    if let Err(e) = crate::reversal::store_reversal(
1135                                        client,
1136                                        schema,
1137                                        table,
1138                                        &ver.raw,
1139                                        reversal_sql,
1140                                    )
1141                                    .await
1142                                    {
1143                                        log::warn!(
1144                                            "Failed to store reversal SQL; version={}, error={}",
1145                                            ver.raw,
1146                                            e
1147                                        );
1148                                    }
1149                                }
1150                                for warning in &result.warnings {
1151                                    log::warn!(
1152                                        "Reversal warning for {}: {}",
1153                                        migration.script,
1154                                        warning
1155                                    );
1156                                }
1157                            }
1158                            Err(e) => {
1159                                log::warn!(
1160                                    "Failed to generate reversal; script={}, error={}",
1161                                    migration.script,
1162                                    e
1163                                );
1164                            }
1165                        }
1166                    }
1167                }
1168            }
1169        }
1170        Err(e) => {
1171            if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
1172                log::error!("Failed to rollback batch transaction: {}", rollback_err);
1173            }
1174            log::error!("Batch migration failed, all changes rolled back: {}", e);
1175            return Err(e);
1176        }
1177    }
1178
1179    // Run afterMigrate hooks (outside the batch transaction)
1180    let after_placeholders = build_placeholders(
1181        &config.placeholders,
1182        schema,
1183        &setup.db_user,
1184        &setup.db_name,
1185        "afterMigrate",
1186    );
1187    let (count, ms) = hooks::run_hooks(
1188        client,
1189        &setup.all_hooks,
1190        &HookType::AfterMigrate,
1191        &after_placeholders,
1192    )
1193    .await?;
1194    report.hooks_executed += count;
1195    report.hooks_time_ms += ms;
1196
1197    Ok(report)
1198}
1199
1200/// Apply a single migration within a transaction.
1201///
1202/// Executes the migration SQL, records it in the history table, and optionally
1203/// commits the transaction. When `hold_transaction` is `true`, the transaction
1204/// is left open so the caller can evaluate ensure guards before committing.
1205#[allow(clippy::too_many_arguments)]
1206async fn apply_migration(
1207    client: &Client,
1208    config: &WaypointConfig,
1209    migration: &ResolvedMigration,
1210    schema: &str,
1211    table: &str,
1212    installed_by: &str,
1213    db_user: &str,
1214    db_name: &str,
1215    hold_transaction: bool,
1216) -> Result<i32> {
1217    log::info!(
1218        "Applying migration; migration={}, schema={}",
1219        migration.script,
1220        schema
1221    );
1222
1223    // Build placeholders
1224    let placeholders = build_placeholders(
1225        &config.placeholders,
1226        schema,
1227        db_user,
1228        db_name,
1229        &migration.script,
1230    );
1231
1232    // Replace placeholders in SQL
1233    let sql = replace_placeholders(&migration.sql, &placeholders)?;
1234
1235    let version_str = migration.version().map(|v| v.raw.as_str());
1236    let type_str = migration.migration_type().to_string();
1237
1238    // Execute migration SQL and history insert atomically in one transaction
1239    let start = std::time::Instant::now();
1240    client.batch_execute("BEGIN").await?;
1241
1242    match client.batch_execute(&sql).await {
1243        Ok(()) => {
1244            let exec_time = start.elapsed().as_millis() as i32;
1245            // Record success inside the same transaction
1246            match history::insert_applied_migration(
1247                client,
1248                schema,
1249                table,
1250                version_str,
1251                &migration.description,
1252                &type_str,
1253                &migration.script,
1254                Some(migration.checksum),
1255                installed_by,
1256                exec_time,
1257                true,
1258            )
1259            .await
1260            {
1261                Ok(()) => {
1262                    if !hold_transaction {
1263                        client.batch_execute("COMMIT").await?;
1264                    }
1265                    Ok(exec_time)
1266                }
1267                Err(e) => {
1268                    if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
1269                        log::error!("Failed to rollback transaction: {}", rollback_err);
1270                    }
1271                    Err(e)
1272                }
1273            }
1274        }
1275        Err(e) => {
1276            if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
1277                log::error!("Failed to rollback transaction: {}", rollback_err);
1278            }
1279
1280            // Record failure — best-effort outside the rolled-back transaction
1281            if let Err(record_err) = history::insert_applied_migration(
1282                client,
1283                schema,
1284                table,
1285                version_str,
1286                &migration.description,
1287                &type_str,
1288                &migration.script,
1289                Some(migration.checksum),
1290                installed_by,
1291                0,
1292                false,
1293            )
1294            .await
1295            {
1296                log::warn!(
1297                    "Failed to record migration failure in history table; script={}, error={}",
1298                    migration.script,
1299                    record_err
1300                );
1301            }
1302
1303            // Extract detailed error message
1304            let reason = crate::error::format_db_error(&e);
1305            log::error!(
1306                "Migration failed; script={}, reason={}",
1307                migration.script,
1308                reason
1309            );
1310            Err(WaypointError::MigrationFailed {
1311                script: migration.script.clone(),
1312                reason,
1313            })
1314        }
1315    }
1316}
1317
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `validate_batch_compatible` and unpack the expected
    /// `NonTransactionalStatement` error into `(script, statement)`.
    /// Panics if validation passes or fails with a different error.
    fn expect_non_transactional(script: &str, sql: &str) -> (String, String) {
        match validate_batch_compatible(script, sql) {
            Err(WaypointError::NonTransactionalStatement { script, statement }) => {
                (script, statement)
            }
            Err(other) => panic!("Expected NonTransactionalStatement, got {:?}", other),
            Ok(()) => panic!("Expected NonTransactionalStatement, but validation passed"),
        }
    }

    /// Build directives that restrict a migration to the given environments.
    fn directives_for(envs: &[&str]) -> MigrationDirectives {
        MigrationDirectives {
            env: envs.iter().map(|e| e.to_string()).collect(),
            ..Default::default()
        }
    }

    #[test]
    fn test_detect_concurrent_index() {
        let (script, _) = expect_non_transactional(
            "V5__Add_index.sql",
            "CREATE INDEX CONCURRENTLY idx_users_email ON users (email);",
        );
        assert_eq!(script, "V5__Add_index.sql");
    }

    #[test]
    fn test_detect_drop_index_concurrently() {
        let (_, statement) = expect_non_transactional(
            "V6__Drop_index.sql",
            "DROP INDEX CONCURRENTLY idx_users_email;",
        );
        assert!(statement.contains("DROP INDEX CONCURRENTLY"));
    }

    #[test]
    fn test_detect_vacuum() {
        let (_, statement) = expect_non_transactional("V7__Vacuum.sql", "VACUUM ANALYZE users;");
        assert_eq!(statement, "VACUUM");
    }

    #[test]
    fn test_detect_create_database() {
        // Only the rejection matters here, not which error variant carries it.
        let outcome = validate_batch_compatible("V8__Create_db.sql", "CREATE DATABASE newdb;");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_batch_compatible_normal_ddl() {
        let ddl =
            "CREATE TABLE users (id SERIAL PRIMARY KEY); CREATE INDEX idx_users ON users (id);";
        assert!(validate_batch_compatible("V1__Init.sql", ddl).is_ok());
    }

    #[test]
    fn test_should_run_in_environment_no_directives() {
        let directives = MigrationDirectives::default();
        for env in [Some("production"), None] {
            assert!(should_run_in_environment(&directives, env));
        }
    }

    #[test]
    fn test_should_run_in_environment_matches() {
        let directives = directives_for(&["production", "staging"]);
        for env in ["production", "staging"] {
            assert!(should_run_in_environment(&directives, Some(env)));
        }
        assert!(!should_run_in_environment(&directives, Some("dev")));
    }

    #[test]
    fn test_should_run_in_environment_case_insensitive() {
        let directives = directives_for(&["PROD"]);
        for env in ["prod", "PROD", "Prod"] {
            assert!(should_run_in_environment(&directives, Some(env)));
        }
        assert!(!should_run_in_environment(&directives, Some("dev")));
    }

    #[test]
    fn test_should_run_in_environment_no_env_configured() {
        let directives = directives_for(&["prod"]);
        // No environment configured = runs everything
        assert!(should_run_in_environment(&directives, None));
    }
}