Skip to main content

waypoint_core/
migration.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::sync::LazyLock;
4
5use regex::Regex;
6
7use crate::checksum::calculate_checksum;
8use crate::error::{Result, WaypointError};
9use crate::hooks;
10
11static VERSIONED_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^V([\d._]+)__(.+)$").unwrap());
12static REPEATABLE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^R__(.+)$").unwrap());
13
14/// A parsed migration version, supporting dotted numeric segments (e.g., "1.2.3").
15#[derive(Debug, Clone, Eq, PartialEq)]
16pub struct MigrationVersion {
17    pub segments: Vec<u64>,
18    pub raw: String,
19}
20
21impl MigrationVersion {
22    pub fn parse(raw: &str) -> Result<Self> {
23        if raw.is_empty() {
24            return Err(WaypointError::MigrationParseError(
25                "Version string is empty".to_string(),
26            ));
27        }
28
29        // Support both "." and "_" as segment separators
30        let segments: std::result::Result<Vec<u64>, _> =
31            raw.split(['.', '_']).map(|s| s.parse::<u64>()).collect();
32
33        let segments = segments.map_err(|e| {
34            WaypointError::MigrationParseError(format!(
35                "Invalid version segment in '{}': {}",
36                raw, e
37            ))
38        })?;
39
40        Ok(MigrationVersion {
41            segments,
42            raw: raw.to_string(),
43        })
44    }
45}
46
47impl Ord for MigrationVersion {
48    fn cmp(&self, other: &Self) -> Ordering {
49        let max_len = self.segments.len().max(other.segments.len());
50        for i in 0..max_len {
51            let a = self.segments.get(i).copied().unwrap_or(0);
52            let b = other.segments.get(i).copied().unwrap_or(0);
53            match a.cmp(&b) {
54                Ordering::Equal => continue,
55                ord => return ord,
56            }
57        }
58        Ordering::Equal
59    }
60}
61
62impl PartialOrd for MigrationVersion {
63    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
64        Some(self.cmp(other))
65    }
66}
67
68impl fmt::Display for MigrationVersion {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        write!(f, "{}", self.raw)
71    }
72}
73
/// How a migration is labelled when it is displayed or serialized.
///
/// [`ResolvedMigration::migration_type`] derives this from a [`MigrationKind`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MigrationType {
    /// A `V{version}__{description}.sql` migration.
    Versioned,
    /// An `R__{description}.sql` migration.
    Repeatable,
}
82
83impl fmt::Display for MigrationType {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            MigrationType::Versioned => write!(f, "SQL"),
87            MigrationType::Repeatable => write!(f, "SQL_REPEATABLE"),
88        }
89    }
90}
91
92/// Type-safe encoding of the migration variant.
93///
94/// Versioned migrations always have a version; repeatable migrations never do.
95/// This eliminates the `Option<MigrationVersion>` + `MigrationType` redundancy.
96#[derive(Debug, Clone)]
97pub enum MigrationKind {
98    Versioned(MigrationVersion),
99    Repeatable,
100}
101
102/// A migration file discovered on disk.
103#[derive(Debug, Clone)]
104pub struct ResolvedMigration {
105    pub kind: MigrationKind,
106    pub description: String,
107    pub script: String,
108    pub checksum: i32,
109    pub sql: String,
110}
111
112impl ResolvedMigration {
113    /// Get the version if this is a versioned migration.
114    pub fn version(&self) -> Option<&MigrationVersion> {
115        match &self.kind {
116            MigrationKind::Versioned(v) => Some(v),
117            MigrationKind::Repeatable => None,
118        }
119    }
120
121    /// Get the migration type for display/serialization.
122    pub fn migration_type(&self) -> MigrationType {
123        match &self.kind {
124            MigrationKind::Versioned(_) => MigrationType::Versioned,
125            MigrationKind::Repeatable => MigrationType::Repeatable,
126        }
127    }
128
129    /// Whether this is a versioned migration.
130    pub fn is_versioned(&self) -> bool {
131        matches!(&self.kind, MigrationKind::Versioned(_))
132    }
133}
134
135/// Parse a migration filename into its components.
136///
137/// Expected patterns:
138///   V{version}__{description}.sql  — versioned migration
139///   R__{description}.sql           — repeatable migration
140pub fn parse_migration_filename(filename: &str) -> Result<(MigrationKind, String)> {
141    // Strip .sql extension
142    let stem = filename.strip_suffix(".sql").ok_or_else(|| {
143        WaypointError::MigrationParseError(format!(
144            "Migration file '{}' does not have .sql extension",
145            filename
146        ))
147    })?;
148
149    if let Some(caps) = VERSIONED_RE.captures(stem) {
150        let version_str = caps.get(1).unwrap().as_str();
151        let description = caps.get(2).unwrap().as_str().replace('_', " ");
152        let version = MigrationVersion::parse(version_str)?;
153        Ok((MigrationKind::Versioned(version), description))
154    } else if let Some(caps) = REPEATABLE_RE.captures(stem) {
155        let description = caps.get(1).unwrap().as_str().replace('_', " ");
156        Ok((MigrationKind::Repeatable, description))
157    } else {
158        Err(WaypointError::MigrationParseError(format!(
159            "Migration file '{}' does not match V{{version}}__{{description}}.sql or R__{{description}}.sql pattern",
160            filename
161        )))
162    }
163}
164
165/// Scan migration locations for SQL files and parse them into ResolvedMigrations.
166pub fn scan_migrations(locations: &[std::path::PathBuf]) -> Result<Vec<ResolvedMigration>> {
167    let mut migrations = Vec::new();
168
169    for location in locations {
170        if !location.exists() {
171            tracing::warn!("Migration location does not exist: {}", location.display());
172            continue;
173        }
174
175        let entries = std::fs::read_dir(location).map_err(|e| {
176            WaypointError::IoError(std::io::Error::new(
177                e.kind(),
178                format!(
179                    "Failed to read migration directory '{}': {}",
180                    location.display(),
181                    e
182                ),
183            ))
184        })?;
185
186        for entry in entries {
187            let entry = entry?;
188            let path = entry.path();
189
190            if !path.is_file() {
191                continue;
192            }
193
194            let filename = match path.file_name().and_then(|n| n.to_str()) {
195                Some(name) => name.to_string(),
196                None => continue,
197            };
198
199            // Skip non-SQL files
200            if !filename.ends_with(".sql") {
201                continue;
202            }
203
204            // Skip hook callback files
205            if hooks::is_hook_file(&filename) {
206                continue;
207            }
208
209            // Skip files that don't start with V or R
210            if !filename.starts_with('V') && !filename.starts_with('R') {
211                continue;
212            }
213
214            let (kind, description) = parse_migration_filename(&filename)?;
215            let sql = std::fs::read_to_string(&path)?;
216            let checksum = calculate_checksum(&sql);
217
218            migrations.push(ResolvedMigration {
219                kind,
220                description,
221                script: filename,
222                checksum,
223                sql,
224            });
225        }
226    }
227
228    // Sort: versioned by version, repeatable by description
229    migrations.sort_by(|a, b| match (&a.kind, &b.kind) {
230        (MigrationKind::Versioned(va), MigrationKind::Versioned(vb)) => va.cmp(vb),
231        (MigrationKind::Versioned(_), MigrationKind::Repeatable) => Ordering::Less,
232        (MigrationKind::Repeatable, MigrationKind::Versioned(_)) => Ordering::Greater,
233        (MigrationKind::Repeatable, MigrationKind::Repeatable) => a.description.cmp(&b.description),
234    });
235
236    Ok(migrations)
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `raw` as a version, panicking with the input and error if it is rejected.
    fn version(raw: &str) -> MigrationVersion {
        MigrationVersion::parse(raw).unwrap_or_else(|e| panic!("failed to parse {raw:?}: {e:?}"))
    }

    /// Parse `filename`, panicking with the input and error if it is rejected.
    fn parse_ok(filename: &str) -> (MigrationKind, String) {
        parse_migration_filename(filename)
            .unwrap_or_else(|e| panic!("failed to parse {filename:?}: {e:?}"))
    }

    #[test]
    fn test_version_parsing() {
        let cases: [(&str, &[u64]); 3] = [
            ("1", &[1]),
            ("1.2.3", &[1, 2, 3]),
            ("1_2_3", &[1, 2, 3]),
        ];
        for (raw, expected) in cases {
            assert_eq!(version(raw).segments, expected, "segments of {raw:?}");
        }
    }

    #[test]
    fn test_version_ordering() {
        let [v1, v2, v1_9, v1_10, v1_2, v1_2_0] =
            ["1", "2", "1.9", "1.10", "1.2", "1.2.0"].map(version);

        assert!(v1 < v2);
        assert!(v1_9 < v1_10, "segments must compare numerically, not as strings");
        assert!(v1_2 < v1_9);
        assert_eq!(
            v1_2.cmp(&v1_2_0),
            Ordering::Equal,
            "trailing zero segments are ignored"
        );
    }

    #[test]
    fn test_version_parse_error() {
        for raw in ["", "abc"] {
            assert!(MigrationVersion::parse(raw).is_err(), "{raw:?} should be rejected");
        }
    }

    #[test]
    fn test_parse_versioned_filename() {
        let (kind, desc) = parse_ok("V1__Create_users.sql");
        let MigrationKind::Versioned(v) = kind else {
            panic!("Expected Versioned");
        };
        assert_eq!(v.segments, vec![1]);
        assert_eq!(desc, "Create users");
    }

    #[test]
    fn test_parse_versioned_dotted_version() {
        let (kind, desc) = parse_ok("V1.2.3__Add_column.sql");
        let MigrationKind::Versioned(v) = kind else {
            panic!("Expected Versioned");
        };
        assert_eq!(v.segments, vec![1, 2, 3]);
        assert_eq!(desc, "Add column");
    }

    #[test]
    fn test_parse_repeatable_filename() {
        let (kind, desc) = parse_ok("R__Create_user_view.sql");
        assert!(matches!(kind, MigrationKind::Repeatable));
        assert_eq!(desc, "Create user view");
    }

    #[test]
    fn test_parse_invalid_filename() {
        for name in ["random.sql", "V1_missing_separator.sql", "V1__no_ext"] {
            assert!(
                parse_migration_filename(name).is_err(),
                "{name:?} should be rejected"
            );
        }
    }
}