Skip to main content

waypoint_core/
migration.rs

//! Migration file parsing, scanning, and types.
//!
//! Supports versioned (`V{version}__{desc}.sql`) and repeatable (`R__{desc}.sql`) migrations.
4
5use std::cmp::Ordering;
6use std::fmt;
7use std::sync::LazyLock;
8
9use regex::Regex;
10
11use crate::checksum::calculate_checksum;
12use crate::error::{Result, WaypointError};
13use crate::hooks;
14
15static VERSIONED_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^V([\d._]+)__(.+)$").unwrap());
16static REPEATABLE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^R__(.+)$").unwrap());
17
18/// A parsed migration version, supporting dotted numeric segments (e.g., "1.2.3").
19#[derive(Debug, Clone, Eq, PartialEq)]
20pub struct MigrationVersion {
21    pub segments: Vec<u64>,
22    pub raw: String,
23}
24
25impl MigrationVersion {
26    /// Parse a version string like `"1.2.3"` or `"1_2"` into segments.
27    pub fn parse(raw: &str) -> Result<Self> {
28        if raw.is_empty() {
29            return Err(WaypointError::MigrationParseError(
30                "Version string is empty".to_string(),
31            ));
32        }
33
34        // Support both "." and "_" as segment separators
35        let segments: std::result::Result<Vec<u64>, _> =
36            raw.split(['.', '_']).map(|s| s.parse::<u64>()).collect();
37
38        let segments = segments.map_err(|e| {
39            WaypointError::MigrationParseError(format!(
40                "Invalid version segment in '{}': {}",
41                raw, e
42            ))
43        })?;
44
45        Ok(MigrationVersion {
46            segments,
47            raw: raw.to_string(),
48        })
49    }
50}
51
52impl Ord for MigrationVersion {
53    fn cmp(&self, other: &Self) -> Ordering {
54        let max_len = self.segments.len().max(other.segments.len());
55        for i in 0..max_len {
56            let a = self.segments.get(i).copied().unwrap_or(0);
57            let b = other.segments.get(i).copied().unwrap_or(0);
58            match a.cmp(&b) {
59                Ordering::Equal => continue,
60                ord => return ord,
61            }
62        }
63        Ordering::Equal
64    }
65}
66
67impl PartialOrd for MigrationVersion {
68    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
69        Some(self.cmp(other))
70    }
71}
72
73impl fmt::Display for MigrationVersion {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        write!(f, "{}", self.raw)
76    }
77}
78
/// The type of a migration (for display/serialization).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationType {
    /// V{version}__{description}.sql — displayed as `SQL`.
    Versioned,
    /// R__{description}.sql — displayed as `SQL_REPEATABLE`.
    Repeatable,
}

impl fmt::Display for MigrationType {
    /// Writes the type label. It goes through [`fmt::Formatter::pad`] so width, fill and
    /// alignment flags (e.g. `{:<16}`) are honoured when aligning labels in a table.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            MigrationType::Versioned => "SQL",
            MigrationType::Repeatable => "SQL_REPEATABLE",
        };
        f.pad(label)
    }
}
96
97/// Type-safe encoding of the migration variant.
98///
99/// Versioned migrations always have a version; repeatable migrations never do.
100/// This eliminates the `Option<MigrationVersion>` + `MigrationType` redundancy.
101#[derive(Debug, Clone)]
102pub enum MigrationKind {
103    Versioned(MigrationVersion),
104    Repeatable,
105}
106
107/// A migration file discovered on disk.
108#[derive(Debug, Clone)]
109pub struct ResolvedMigration {
110    pub kind: MigrationKind,
111    pub description: String,
112    pub script: String,
113    pub checksum: i32,
114    pub sql: String,
115}
116
117impl ResolvedMigration {
118    /// Get the version if this is a versioned migration.
119    pub fn version(&self) -> Option<&MigrationVersion> {
120        match &self.kind {
121            MigrationKind::Versioned(v) => Some(v),
122            MigrationKind::Repeatable => None,
123        }
124    }
125
126    /// Get the migration type for display/serialization.
127    pub fn migration_type(&self) -> MigrationType {
128        match &self.kind {
129            MigrationKind::Versioned(_) => MigrationType::Versioned,
130            MigrationKind::Repeatable => MigrationType::Repeatable,
131        }
132    }
133
134    /// Whether this is a versioned migration.
135    pub fn is_versioned(&self) -> bool {
136        matches!(&self.kind, MigrationKind::Versioned(_))
137    }
138}
139
140/// Parse a migration filename into its components.
141///
142/// Expected patterns:
143///   V{version}__{description}.sql  — versioned migration
144///   R__{description}.sql           — repeatable migration
145pub fn parse_migration_filename(filename: &str) -> Result<(MigrationKind, String)> {
146    // Strip .sql extension
147    let stem = filename.strip_suffix(".sql").ok_or_else(|| {
148        WaypointError::MigrationParseError(format!(
149            "Migration file '{}' does not have .sql extension",
150            filename
151        ))
152    })?;
153
154    if let Some(caps) = VERSIONED_RE.captures(stem) {
155        let version_str = caps.get(1).unwrap().as_str();
156        let description = caps.get(2).unwrap().as_str().replace('_', " ");
157        let version = MigrationVersion::parse(version_str)?;
158        Ok((MigrationKind::Versioned(version), description))
159    } else if let Some(caps) = REPEATABLE_RE.captures(stem) {
160        let description = caps.get(1).unwrap().as_str().replace('_', " ");
161        Ok((MigrationKind::Repeatable, description))
162    } else {
163        Err(WaypointError::MigrationParseError(format!(
164            "Migration file '{}' does not match V{{version}}__{{description}}.sql or R__{{description}}.sql pattern",
165            filename
166        )))
167    }
168}
169
170/// Scan migration locations for SQL files and parse them into ResolvedMigrations.
171pub fn scan_migrations(locations: &[std::path::PathBuf]) -> Result<Vec<ResolvedMigration>> {
172    let mut migrations = Vec::new();
173
174    for location in locations {
175        if !location.exists() {
176            tracing::warn!("Migration location does not exist: {}", location.display());
177            continue;
178        }
179
180        let entries = std::fs::read_dir(location).map_err(|e| {
181            WaypointError::IoError(std::io::Error::new(
182                e.kind(),
183                format!(
184                    "Failed to read migration directory '{}': {}",
185                    location.display(),
186                    e
187                ),
188            ))
189        })?;
190
191        for entry in entries {
192            let entry = entry?;
193            let path = entry.path();
194
195            if !path.is_file() {
196                continue;
197            }
198
199            let filename = match path.file_name().and_then(|n| n.to_str()) {
200                Some(name) => name.to_string(),
201                None => continue,
202            };
203
204            // Skip non-SQL files
205            if !filename.ends_with(".sql") {
206                continue;
207            }
208
209            // Skip hook callback files
210            if hooks::is_hook_file(&filename) {
211                continue;
212            }
213
214            // Skip files that don't start with V or R
215            if !filename.starts_with('V') && !filename.starts_with('R') {
216                continue;
217            }
218
219            let (kind, description) = parse_migration_filename(&filename)?;
220            let sql = std::fs::read_to_string(&path)?;
221            let checksum = calculate_checksum(&sql);
222
223            migrations.push(ResolvedMigration {
224                kind,
225                description,
226                script: filename,
227                checksum,
228                sql,
229            });
230        }
231    }
232
233    // Sort: versioned by version, repeatable by description
234    migrations.sort_by(|a, b| match (&a.kind, &b.kind) {
235        (MigrationKind::Versioned(va), MigrationKind::Versioned(vb)) => va.cmp(vb),
236        (MigrationKind::Versioned(_), MigrationKind::Repeatable) => Ordering::Less,
237        (MigrationKind::Repeatable, MigrationKind::Versioned(_)) => Ordering::Greater,
238        (MigrationKind::Repeatable, MigrationKind::Repeatable) => a.description.cmp(&b.description),
239    });
240
241    Ok(migrations)
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_parsing() {
        let v = MigrationVersion::parse("1").unwrap();
        assert_eq!(v.segments, vec![1]);

        let v = MigrationVersion::parse("1.2.3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);

        let v = MigrationVersion::parse("1_2_3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);

        // Separators may be mixed; Display keeps the text exactly as written.
        let v = MigrationVersion::parse("1.2_3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);
        assert_eq!(v.to_string(), "1.2_3");
    }

    #[test]
    fn test_version_ordering() {
        let v1 = MigrationVersion::parse("1").unwrap();
        let v2 = MigrationVersion::parse("2").unwrap();
        let v1_9 = MigrationVersion::parse("1.9").unwrap();
        let v1_10 = MigrationVersion::parse("1.10").unwrap();
        let v1_2 = MigrationVersion::parse("1.2").unwrap();
        let v1_2_0 = MigrationVersion::parse("1.2.0").unwrap();
        let v1_2_underscore = MigrationVersion::parse("1_2").unwrap();

        assert!(v1 < v2);
        assert!(v1_9 < v1_10); // Numeric, not string comparison
        assert!(v1_2 < v1_9);
        assert_eq!(v1_2.cmp(&v1_2_0), Ordering::Equal); // Trailing zeros are equal
        assert_eq!(v1_2.cmp(&v1_2_underscore), Ordering::Equal); // Separator does not matter
    }

    #[test]
    fn test_version_parse_error() {
        assert!(MigrationVersion::parse("").is_err());
        assert!(MigrationVersion::parse("abc").is_err());
        assert!(MigrationVersion::parse("1..2").is_err()); // empty middle segment
        assert!(MigrationVersion::parse("1.").is_err()); // trailing separator
        assert!(MigrationVersion::parse("-1").is_err()); // negative
        assert!(MigrationVersion::parse("18446744073709551616").is_err()); // u64::MAX + 1
    }

    #[test]
    fn test_parse_versioned_filename() {
        let (kind, desc) = parse_migration_filename("V1__Create_users.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => assert_eq!(v.segments, vec![1]),
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Create users");
    }

    #[test]
    fn test_parse_versioned_dotted_version() {
        let (kind, desc) = parse_migration_filename("V1.2.3__Add_column.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => assert_eq!(v.segments, vec![1, 2, 3]),
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Add column");
    }

    #[test]
    fn test_parse_versioned_underscore_version() {
        let (kind, desc) = parse_migration_filename("V2_1__Add_index.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => {
                assert_eq!(v.segments, vec![2, 1]);
                assert_eq!(v.raw, "2_1");
            }
            MigrationKind::Repeatable => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Add index");
    }

    #[test]
    fn test_parse_repeatable_filename() {
        let (kind, desc) = parse_migration_filename("R__Create_user_view.sql").unwrap();
        assert!(matches!(kind, MigrationKind::Repeatable));
        assert_eq!(desc, "Create user view");
    }

    #[test]
    fn test_parse_invalid_filename() {
        assert!(parse_migration_filename("random.sql").is_err());
        assert!(parse_migration_filename("V1_missing_separator.sql").is_err());
        assert!(parse_migration_filename("V1__no_ext").is_err());
        assert!(parse_migration_filename("V1__.sql").is_err()); // empty description
        assert!(parse_migration_filename("R__.sql").is_err()); // empty description
        assert!(parse_migration_filename("v1__lowercase_prefix.sql").is_err()); // prefix is case-sensitive
    }

    #[test]
    fn test_resolved_migration_accessors() {
        let versioned = ResolvedMigration {
            kind: MigrationKind::Versioned(MigrationVersion::parse("3").unwrap()),
            description: "Seed data".to_string(),
            script: "V3__Seed_data.sql".to_string(),
            checksum: 0,
            sql: String::new(),
        };
        assert!(versioned.is_versioned());
        assert_eq!(versioned.version().map(|v| v.segments.clone()), Some(vec![3]));
        assert_eq!(versioned.migration_type(), MigrationType::Versioned);

        let repeatable = ResolvedMigration {
            kind: MigrationKind::Repeatable,
            ..versioned.clone()
        };
        assert!(!repeatable.is_versioned());
        assert!(repeatable.version().is_none());
        assert_eq!(repeatable.migration_type(), MigrationType::Repeatable);
    }

    #[test]
    fn test_migration_type_display() {
        assert_eq!(MigrationType::Versioned.to_string(), "SQL");
        assert_eq!(MigrationType::Repeatable.to_string(), "SQL_REPEATABLE");
    }
}