1use std::cmp::Ordering;
6use std::fmt;
7use std::sync::LazyLock;
8
9use regex::Regex;
10
11use crate::checksum::calculate_checksum;
12use crate::error::{Result, WaypointError};
13use crate::hooks;
14
15static VERSIONED_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^V([\d._]+)__(.+)$").unwrap());
16static REPEATABLE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^R__(.+)$").unwrap());
17
/// A parsed migration version such as `1.2.3` or `1_2_3`.
///
/// Equality is defined to agree with the ordering implemented for this type:
/// missing trailing segments count as zero, so `1.2`, `1.2.0` and `1_2` are
/// all the same version even though their `raw` text differs.
#[derive(Debug, Clone)]
pub struct MigrationVersion {
    /// Numeric components of the version, most significant first.
    pub segments: Vec<u64>,
    /// The version text exactly as it was supplied; kept for display only.
    pub raw: String,
}

impl PartialEq for MigrationVersion {
    /// Compares the numeric segments only, ignoring trailing zero segments
    /// and the `raw` spelling, so that `a == b` holds exactly when comparing
    /// the two versions with zero-padding yields `Ordering::Equal`.
    fn eq(&self, other: &Self) -> bool {
        // Drops trailing zero segments: `[1, 2, 0]` becomes `[1, 2]`.
        fn significant(segments: &[u64]) -> &[u64] {
            let len = segments
                .iter()
                .rposition(|&segment| segment != 0)
                .map_or(0, |last_non_zero| last_non_zero + 1);
            &segments[..len]
        }

        significant(&self.segments) == significant(&other.segments)
    }
}

impl Eq for MigrationVersion {}
24
25impl MigrationVersion {
26 pub fn parse(raw: &str) -> Result<Self> {
28 if raw.is_empty() {
29 return Err(WaypointError::MigrationParseError(
30 "Version string is empty".to_string(),
31 ));
32 }
33
34 let segments: std::result::Result<Vec<u64>, _> =
36 raw.split(['.', '_']).map(|s| s.parse::<u64>()).collect();
37
38 let segments = segments.map_err(|e| {
39 WaypointError::MigrationParseError(format!(
40 "Invalid version segment in '{}': {}",
41 raw, e
42 ))
43 })?;
44
45 Ok(MigrationVersion {
46 segments,
47 raw: raw.to_string(),
48 })
49 }
50}
51
52impl Ord for MigrationVersion {
53 fn cmp(&self, other: &Self) -> Ordering {
54 let max_len = self.segments.len().max(other.segments.len());
55 for i in 0..max_len {
56 let a = self.segments.get(i).copied().unwrap_or(0);
57 let b = other.segments.get(i).copied().unwrap_or(0);
58 match a.cmp(&b) {
59 Ordering::Equal => continue,
60 ord => return ord,
61 }
62 }
63 Ordering::Equal
64 }
65}
66
67impl PartialOrd for MigrationVersion {
68 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
69 Some(self.cmp(other))
70 }
71}
72
73impl fmt::Display for MigrationVersion {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 write!(f, "{}", self.raw)
76 }
77}
78
/// The category of a migration, without any associated data.
///
/// Its `Display` output is `SQL` for versioned and `SQL_REPEATABLE` for
/// repeatable migrations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationType {
    /// A migration identified by a version (`V{version}__{description}.sql`).
    Versioned,
    /// A migration without a version (`R__{description}.sql`).
    Repeatable,
}
87
88impl fmt::Display for MigrationType {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 match self {
91 MigrationType::Versioned => write!(f, "SQL"),
92 MigrationType::Repeatable => write!(f, "SQL_REPEATABLE"),
93 }
94 }
95}
96
97#[derive(Debug, Clone)]
102pub enum MigrationKind {
103 Versioned(MigrationVersion),
104 Repeatable,
105}
106
107#[derive(Debug, Clone)]
109pub struct ResolvedMigration {
110 pub kind: MigrationKind,
111 pub description: String,
112 pub script: String,
113 pub checksum: i32,
114 pub sql: String,
115}
116
117impl ResolvedMigration {
118 pub fn version(&self) -> Option<&MigrationVersion> {
120 match &self.kind {
121 MigrationKind::Versioned(v) => Some(v),
122 MigrationKind::Repeatable => None,
123 }
124 }
125
126 pub fn migration_type(&self) -> MigrationType {
128 match &self.kind {
129 MigrationKind::Versioned(_) => MigrationType::Versioned,
130 MigrationKind::Repeatable => MigrationType::Repeatable,
131 }
132 }
133
134 pub fn is_versioned(&self) -> bool {
136 matches!(&self.kind, MigrationKind::Versioned(_))
137 }
138}
139
140pub fn parse_migration_filename(filename: &str) -> Result<(MigrationKind, String)> {
146 let stem = filename.strip_suffix(".sql").ok_or_else(|| {
148 WaypointError::MigrationParseError(format!(
149 "Migration file '{}' does not have .sql extension",
150 filename
151 ))
152 })?;
153
154 if let Some(caps) = VERSIONED_RE.captures(stem) {
155 let version_str = caps.get(1).unwrap().as_str();
156 let description = caps.get(2).unwrap().as_str().replace('_', " ");
157 let version = MigrationVersion::parse(version_str)?;
158 Ok((MigrationKind::Versioned(version), description))
159 } else if let Some(caps) = REPEATABLE_RE.captures(stem) {
160 let description = caps.get(1).unwrap().as_str().replace('_', " ");
161 Ok((MigrationKind::Repeatable, description))
162 } else {
163 Err(WaypointError::MigrationParseError(format!(
164 "Migration file '{}' does not match V{{version}}__{{description}}.sql or R__{{description}}.sql pattern",
165 filename
166 )))
167 }
168}
169
170pub fn scan_migrations(locations: &[std::path::PathBuf]) -> Result<Vec<ResolvedMigration>> {
172 let mut migrations = Vec::new();
173
174 for location in locations {
175 if !location.exists() {
176 tracing::warn!("Migration location does not exist: {}", location.display());
177 continue;
178 }
179
180 let entries = std::fs::read_dir(location).map_err(|e| {
181 WaypointError::IoError(std::io::Error::new(
182 e.kind(),
183 format!(
184 "Failed to read migration directory '{}': {}",
185 location.display(),
186 e
187 ),
188 ))
189 })?;
190
191 for entry in entries {
192 let entry = entry?;
193 let path = entry.path();
194
195 if !path.is_file() {
196 continue;
197 }
198
199 let filename = match path.file_name().and_then(|n| n.to_str()) {
200 Some(name) => name.to_string(),
201 None => continue,
202 };
203
204 if !filename.ends_with(".sql") {
206 continue;
207 }
208
209 if hooks::is_hook_file(&filename) {
211 continue;
212 }
213
214 if !filename.starts_with('V') && !filename.starts_with('R') {
216 continue;
217 }
218
219 let (kind, description) = parse_migration_filename(&filename)?;
220 let sql = std::fs::read_to_string(&path)?;
221 let checksum = calculate_checksum(&sql);
222
223 migrations.push(ResolvedMigration {
224 kind,
225 description,
226 script: filename,
227 checksum,
228 sql,
229 });
230 }
231 }
232
233 migrations.sort_by(|a, b| match (&a.kind, &b.kind) {
235 (MigrationKind::Versioned(va), MigrationKind::Versioned(vb)) => va.cmp(vb),
236 (MigrationKind::Versioned(_), MigrationKind::Repeatable) => Ordering::Less,
237 (MigrationKind::Repeatable, MigrationKind::Versioned(_)) => Ordering::Greater,
238 (MigrationKind::Repeatable, MigrationKind::Repeatable) => a.description.cmp(&b.description),
239 });
240
241 Ok(migrations)
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a version that the test expects to be valid.
    fn version(raw: &str) -> MigrationVersion {
        MigrationVersion::parse(raw).unwrap()
    }

    #[test]
    fn test_version_parsing() {
        // Dots and underscores are interchangeable separators.
        let cases = [
            ("1", vec![1]),
            ("1.2.3", vec![1, 2, 3]),
            ("1_2_3", vec![1, 2, 3]),
        ];
        for (raw, expected) in cases {
            assert_eq!(version(raw).segments, expected);
        }
    }

    #[test]
    fn test_version_ordering() {
        assert!(version("1") < version("2"));
        // Segments compare numerically, not lexically.
        assert!(version("1.9") < version("1.10"));
        assert!(version("1.2") < version("1.9"));
        // A missing trailing segment counts as zero.
        assert_eq!(version("1.2").cmp(&version("1.2.0")), Ordering::Equal);
    }

    #[test]
    fn test_version_parse_error() {
        for raw in ["", "abc"] {
            assert!(MigrationVersion::parse(raw).is_err());
        }
    }

    #[test]
    fn test_parse_versioned_filename() {
        let (kind, desc) = parse_migration_filename("V1__Create_users.sql").unwrap();
        let MigrationKind::Versioned(v) = kind else {
            panic!("Expected Versioned");
        };
        assert_eq!(v.segments, vec![1]);
        assert_eq!(desc, "Create users");
    }

    #[test]
    fn test_parse_versioned_dotted_version() {
        let (kind, desc) = parse_migration_filename("V1.2.3__Add_column.sql").unwrap();
        let MigrationKind::Versioned(v) = kind else {
            panic!("Expected Versioned");
        };
        assert_eq!(v.segments, vec![1, 2, 3]);
        assert_eq!(desc, "Add column");
    }

    #[test]
    fn test_parse_repeatable_filename() {
        let (kind, desc) = parse_migration_filename("R__Create_user_view.sql").unwrap();
        assert!(matches!(kind, MigrationKind::Repeatable));
        assert_eq!(desc, "Create user view");
    }

    #[test]
    fn test_parse_invalid_filename() {
        // No prefix, single-underscore separator, and missing extension.
        let rejected = ["random.sql", "V1_missing_separator.sql", "V1__no_ext"];
        for filename in rejected {
            assert!(parse_migration_filename(filename).is_err());
        }
    }
}