Skip to main content

waypoint_core/
migration.rs

//! Migration file parsing, scanning, and types.
//!
//! Supports versioned (`V{version}__{desc}.sql`), undo (`U{version}__{desc}.sql`), and
//! repeatable (`R__{desc}.sql`) migrations.
4
5use std::cmp::Ordering;
6use std::fmt;
7use std::sync::LazyLock;
8
9use regex_lite::Regex;
10
11use crate::checksum::calculate_checksum;
12use crate::directive::{self, MigrationDirectives};
13use crate::error::{Result, WaypointError};
14use crate::hooks;
15
16static VERSIONED_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^V([\d._]+)__(.+)$").unwrap());
17static UNDO_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^U([\d._]+)__(.+)$").unwrap());
18static REPEATABLE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^R__(.+)$").unwrap());
19
/// A parsed migration version, supporting dotted numeric segments (e.g., "1.2.3").
///
/// Equality ignores the original spelling and any trailing zero segments, so `1.2`, `1_2` and
/// `1.2.0` are all equal. This keeps `==` consistent with the `Ord` implementation, which pads
/// the shorter segment list with zeros before comparing. (A derived `PartialEq` would compare
/// `raw` and the exact segment vectors, contradicting `cmp`.)
#[derive(Debug, Clone)]
pub struct MigrationVersion {
    /// Parsed numeric segments of the version (e.g., `[1, 2, 3]` for `"1.2.3"`).
    pub segments: Vec<u64>,
    /// Original version string as it appeared in the filename.
    pub raw: String,
}

impl MigrationVersion {
    /// The segments with trailing zeros removed; two versions are equal exactly when these match.
    fn significant_segments(&self) -> &[u64] {
        let len = self
            .segments
            .iter()
            .rposition(|&segment| segment != 0)
            .map_or(0, |last_nonzero| last_nonzero + 1);
        &self.segments[..len]
    }
}

impl PartialEq for MigrationVersion {
    fn eq(&self, other: &Self) -> bool {
        self.significant_segments() == other.significant_segments()
    }
}

impl Eq for MigrationVersion {}
28
29impl MigrationVersion {
30    /// Parse a version string like `"1.2.3"` or `"1_2"` into segments.
31    pub fn parse(raw: &str) -> Result<Self> {
32        if raw.is_empty() {
33            return Err(WaypointError::MigrationParseError(
34                "Version string is empty".to_string(),
35            ));
36        }
37
38        // Support both "." and "_" as segment separators
39        let segments: std::result::Result<Vec<u64>, _> =
40            raw.split(['.', '_']).map(|s| s.parse::<u64>()).collect();
41
42        let segments = segments.map_err(|e| {
43            WaypointError::MigrationParseError(format!(
44                "Invalid version segment in '{}': {}",
45                raw, e
46            ))
47        })?;
48
49        Ok(MigrationVersion {
50            segments,
51            raw: raw.to_string(),
52        })
53    }
54}
55
56impl Ord for MigrationVersion {
57    fn cmp(&self, other: &Self) -> Ordering {
58        let max_len = self.segments.len().max(other.segments.len());
59        for i in 0..max_len {
60            let a = self.segments.get(i).copied().unwrap_or(0);
61            let b = other.segments.get(i).copied().unwrap_or(0);
62            match a.cmp(&b) {
63                Ordering::Equal => continue,
64                ord => return ord,
65            }
66        }
67        Ordering::Equal
68    }
69}
70
71impl PartialOrd for MigrationVersion {
72    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
73        Some(self.cmp(other))
74    }
75}
76
77impl fmt::Display for MigrationVersion {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        write!(f, "{}", self.raw)
80    }
81}
82
/// Coarse category of a migration, used for display/serialization.
///
/// Unlike [`MigrationKind`], it carries no version. Its `Display` output is `SQL`,
/// `SQL_REPEATABLE`, or `UNDO_SQL` respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationType {
    /// A versioned migration: `V{version}__{description}.sql`.
    Versioned,
    /// A repeatable migration: `R__{description}.sql`.
    Repeatable,
    /// An undo migration: `U{version}__{description}.sql`.
    Undo,
}

impl fmt::Display for MigrationType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Versioned => "SQL",
            Self::Repeatable => "SQL_REPEATABLE",
            Self::Undo => "UNDO_SQL",
        };
        f.write_str(label)
    }
}
103
104/// Type-safe encoding of the migration variant.
105///
106/// Versioned migrations always have a version; repeatable migrations never do.
107/// This eliminates the `Option<MigrationVersion>` + `MigrationType` redundancy.
108#[derive(Debug, Clone)]
109pub enum MigrationKind {
110    /// A versioned migration with an associated version number.
111    Versioned(MigrationVersion),
112    /// A repeatable migration that is re-applied whenever its checksum changes.
113    Repeatable,
114    /// An undo migration that reverses a specific versioned migration.
115    Undo(MigrationVersion),
116}
117
118/// A migration file discovered on disk.
119#[derive(Debug, Clone)]
120pub struct ResolvedMigration {
121    /// Whether this is a versioned, repeatable, or undo migration (with version if applicable).
122    pub kind: MigrationKind,
123    /// Human-readable description extracted from the filename.
124    pub description: String,
125    /// Original filename of the migration script (e.g., `V1__Create_users.sql`).
126    pub script: String,
127    /// CRC32 checksum of the migration SQL content.
128    pub checksum: i32,
129    /// Raw SQL content of the migration file.
130    pub sql: String,
131    /// Parsed directives from SQL comments (e.g., `@depends`, `@environment`).
132    pub directives: MigrationDirectives,
133}
134
135impl ResolvedMigration {
136    /// Get the version if this is a versioned or undo migration.
137    pub fn version(&self) -> Option<&MigrationVersion> {
138        match &self.kind {
139            MigrationKind::Versioned(v) | MigrationKind::Undo(v) => Some(v),
140            MigrationKind::Repeatable => None,
141        }
142    }
143
144    /// Get the migration type for display/serialization.
145    pub fn migration_type(&self) -> MigrationType {
146        match &self.kind {
147            MigrationKind::Versioned(_) => MigrationType::Versioned,
148            MigrationKind::Repeatable => MigrationType::Repeatable,
149            MigrationKind::Undo(_) => MigrationType::Undo,
150        }
151    }
152
153    /// Whether this is a versioned migration.
154    pub fn is_versioned(&self) -> bool {
155        matches!(&self.kind, MigrationKind::Versioned(_))
156    }
157
158    /// Whether this is an undo migration.
159    pub fn is_undo(&self) -> bool {
160        matches!(&self.kind, MigrationKind::Undo(_))
161    }
162}
163
164/// Parse a migration filename into its components.
165///
166/// Expected patterns:
167///   V{version}__{description}.sql  — versioned migration
168///   R__{description}.sql           — repeatable migration
169pub fn parse_migration_filename(filename: &str) -> Result<(MigrationKind, String)> {
170    // Strip .sql extension
171    let stem = filename.strip_suffix(".sql").ok_or_else(|| {
172        WaypointError::MigrationParseError(format!(
173            "Migration file '{}' does not have .sql extension",
174            filename
175        ))
176    })?;
177
178    if let Some(caps) = VERSIONED_RE.captures(stem) {
179        let version_str = caps.get(1).unwrap().as_str();
180        let description = caps.get(2).unwrap().as_str().replace('_', " ");
181        let version = MigrationVersion::parse(version_str)?;
182        Ok((MigrationKind::Versioned(version), description))
183    } else if let Some(caps) = UNDO_RE.captures(stem) {
184        let version_str = caps.get(1).unwrap().as_str();
185        let description = caps.get(2).unwrap().as_str().replace('_', " ");
186        let version = MigrationVersion::parse(version_str)?;
187        Ok((MigrationKind::Undo(version), description))
188    } else if let Some(caps) = REPEATABLE_RE.captures(stem) {
189        let description = caps.get(1).unwrap().as_str().replace('_', " ");
190        Ok((MigrationKind::Repeatable, description))
191    } else {
192        Err(WaypointError::MigrationParseError(format!(
193            "Migration file '{}' does not match V{{version}}__{{description}}.sql, U{{version}}__{{description}}.sql, or R__{{description}}.sql pattern",
194            filename
195        )))
196    }
197}
198
199/// Scan migration locations for SQL files and parse them into ResolvedMigrations.
200pub fn scan_migrations(locations: &[std::path::PathBuf]) -> Result<Vec<ResolvedMigration>> {
201    let mut migrations = Vec::new();
202
203    for location in locations {
204        if !location.exists() {
205            log::warn!("Migration location does not exist: {}", location.display());
206            continue;
207        }
208
209        let entries = std::fs::read_dir(location).map_err(|e| {
210            WaypointError::IoError(std::io::Error::new(
211                e.kind(),
212                format!(
213                    "Failed to read migration directory '{}': {}",
214                    location.display(),
215                    e
216                ),
217            ))
218        })?;
219
220        for entry in entries {
221            let entry = entry?;
222            let path = entry.path();
223
224            if !path.is_file() {
225                continue;
226            }
227
228            let filename = match path.file_name().and_then(|n| n.to_str()) {
229                Some(name) => name.to_string(),
230                None => continue,
231            };
232
233            // Skip non-SQL files
234            if !filename.ends_with(".sql") {
235                continue;
236            }
237
238            // Skip hook callback files
239            if hooks::is_hook_file(&filename) {
240                continue;
241            }
242
243            // Skip files that don't start with V, U, or R
244            if !filename.starts_with('V')
245                && !filename.starts_with('U')
246                && !filename.starts_with('R')
247            {
248                continue;
249            }
250
251            let (kind, description) = match parse_migration_filename(&filename) {
252                Ok(result) => result,
253                Err(e) => {
254                    log::warn!("Skipping malformed migration file '{}': {}", filename, e);
255                    continue;
256                }
257            };
258            let sql = std::fs::read_to_string(&path)?;
259            let checksum = calculate_checksum(&sql);
260            let directives = directive::parse_directives(&sql);
261
262            migrations.push(ResolvedMigration {
263                kind,
264                description,
265                script: filename,
266                checksum,
267                sql,
268                directives,
269            });
270        }
271    }
272
273    // Sort: versioned by version, then undo by version, then repeatable by description
274    migrations.sort_by(|a, b| {
275        // Order groups: Versioned first, then Undo, then Repeatable
276        fn group_order(kind: &MigrationKind) -> u8 {
277            match kind {
278                MigrationKind::Versioned(_) => 0,
279                MigrationKind::Undo(_) => 1,
280                MigrationKind::Repeatable => 2,
281            }
282        }
283        let ga = group_order(&a.kind);
284        let gb = group_order(&b.kind);
285        if ga != gb {
286            return ga.cmp(&gb);
287        }
288        match (&a.kind, &b.kind) {
289            (MigrationKind::Versioned(va), MigrationKind::Versioned(vb)) => va.cmp(vb),
290            (MigrationKind::Undo(va), MigrationKind::Undo(vb)) => va.cmp(vb),
291            (MigrationKind::Repeatable, MigrationKind::Repeatable) => {
292                a.description.cmp(&b.description)
293            }
294            _ => Ordering::Equal,
295        }
296    });
297
298    // Detect duplicate versions
299    let mut seen_versions: std::collections::HashSet<String> = std::collections::HashSet::new();
300    for m in &migrations {
301        if let Some(v) = m.version() {
302            let prefix = if m.is_versioned() { "V" } else { "U" };
303            let key = format!("{}{}", prefix, v.raw);
304            if !seen_versions.insert(key) {
305                return Err(WaypointError::ValidationFailed(format!(
306                    "Duplicate migration version '{}' found in file '{}'. Each version must be unique.",
307                    v.raw, m.script
308                )));
309            }
310        }
311    }
312
313    Ok(migrations)
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_parsing() {
        let v = MigrationVersion::parse("1").unwrap();
        assert_eq!(v.segments, vec![1]);

        let v = MigrationVersion::parse("1.2.3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);

        let v = MigrationVersion::parse("1_2_3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);
    }

    #[test]
    fn test_version_ordering() {
        let v1 = MigrationVersion::parse("1").unwrap();
        let v2 = MigrationVersion::parse("2").unwrap();
        let v1_9 = MigrationVersion::parse("1.9").unwrap();
        let v1_10 = MigrationVersion::parse("1.10").unwrap();
        let v1_2 = MigrationVersion::parse("1.2").unwrap();
        let v1_2_0 = MigrationVersion::parse("1.2.0").unwrap();

        assert!(v1 < v2);
        assert!(v1_9 < v1_10); // Numeric, not string comparison
        assert!(v1_2 < v1_9);
        assert_eq!(v1_2.cmp(&v1_2_0), Ordering::Equal); // Trailing zeros are equal
    }

    #[test]
    fn test_version_parse_error() {
        assert!(MigrationVersion::parse("").is_err());
        assert!(MigrationVersion::parse("abc").is_err());
        // Empty segments (doubled or trailing separators) are rejected as well.
        assert!(MigrationVersion::parse("1..2").is_err());
        assert!(MigrationVersion::parse("1.").is_err());
    }

    #[test]
    fn test_version_separators_are_interchangeable() {
        let dotted = MigrationVersion::parse("1.2").unwrap();
        let underscored = MigrationVersion::parse("1_2").unwrap();
        assert_eq!(dotted.cmp(&underscored), Ordering::Equal);
        // Display keeps the spelling used in the filename.
        assert_eq!(underscored.to_string(), "1_2");
    }

    #[test]
    fn test_parse_versioned_filename() {
        let (kind, desc) = parse_migration_filename("V1__Create_users.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => assert_eq!(v.segments, vec![1]),
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Create users");
    }

    #[test]
    fn test_parse_versioned_dotted_version() {
        let (kind, desc) = parse_migration_filename("V1.2.3__Add_column.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => assert_eq!(v.segments, vec![1, 2, 3]),
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Add column");
    }

    #[test]
    fn test_parse_repeatable_filename() {
        let (kind, desc) = parse_migration_filename("R__Create_user_view.sql").unwrap();
        assert!(matches!(kind, MigrationKind::Repeatable));
        assert_eq!(desc, "Create user view");
    }

    #[test]
    fn test_parse_invalid_filename() {
        assert!(parse_migration_filename("random.sql").is_err());
        assert!(parse_migration_filename("V1_missing_separator.sql").is_err());
        assert!(parse_migration_filename("V1__no_ext").is_err());
    }

    #[test]
    fn test_parse_undo_filename() {
        let (kind, desc) = parse_migration_filename("U1__Create_users.sql").unwrap();
        match kind {
            MigrationKind::Undo(v) => assert_eq!(v.segments, vec![1]),
            _ => panic!("Expected Undo"),
        }
        assert_eq!(desc, "Create users");
    }

    #[test]
    fn test_parse_undo_dotted_version() {
        let (kind, desc) = parse_migration_filename("U1.2.3__Add_column.sql").unwrap();
        match kind {
            MigrationKind::Undo(v) => assert_eq!(v.segments, vec![1, 2, 3]),
            _ => panic!("Expected Undo"),
        }
        assert_eq!(desc, "Add column");
    }

    #[test]
    fn test_malformed_filename_is_skipped() {
        // Parsing alone rejects malformed names...
        assert!(parse_migration_filename("random.sql").is_err());
        assert!(parse_migration_filename("V1_missing_separator.sql").is_err());

        // ...and scanning skips them (and non-SQL files) instead of failing the whole scan.
        let dir = std::env::temp_dir().join(format!(
            "waypoint_migration_scan_test_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        for (name, body) in [
            ("V1__Create_users.sql", "CREATE TABLE users (id INT);"),
            ("U1__Create_users.sql", "DROP TABLE users;"),
            ("R__User_view.sql", "SELECT 1;"),
            ("V1_missing_separator.sql", "SELECT 1;"),
            ("notes.txt", "not a migration"),
        ] {
            std::fs::write(dir.join(name), body).unwrap();
        }

        let scanned = scan_migrations(&[dir.clone()]);
        let _ = std::fs::remove_dir_all(&dir);

        let scripts: Vec<String> = scanned.unwrap().into_iter().map(|m| m.script).collect();
        // Versioned first, then undo, then repeatable.
        assert_eq!(
            scripts,
            vec![
                "V1__Create_users.sql",
                "U1__Create_users.sql",
                "R__User_view.sql"
            ]
        );
    }

    #[test]
    fn test_undo_is_undo() {
        let m = ResolvedMigration {
            kind: MigrationKind::Undo(MigrationVersion::parse("1").unwrap()),
            description: "test".to_string(),
            script: "U1__test.sql".to_string(),
            checksum: 0,
            sql: String::new(),
            directives: MigrationDirectives::default(),
        };
        assert!(m.is_undo());
        assert!(!m.is_versioned());
        assert_eq!(m.migration_type(), MigrationType::Undo);
        assert_eq!(m.migration_type().to_string(), "UNDO_SQL");
    }

    #[test]
    fn test_repeatable_has_no_version() {
        let m = ResolvedMigration {
            kind: MigrationKind::Repeatable,
            description: "view".to_string(),
            script: "R__view.sql".to_string(),
            checksum: 0,
            sql: String::new(),
            directives: MigrationDirectives::default(),
        };
        assert!(m.version().is_none());
        assert!(!m.is_versioned());
        assert!(!m.is_undo());
        assert_eq!(m.migration_type(), MigrationType::Repeatable);
    }

    #[test]
    fn test_migration_type_display() {
        assert_eq!(MigrationType::Versioned.to_string(), "SQL");
        assert_eq!(MigrationType::Repeatable.to_string(), "SQL_REPEATABLE");
        assert_eq!(MigrationType::Undo.to_string(), "UNDO_SQL");
    }
}