1use std::cmp::Ordering;
2use std::fmt;
3use std::sync::LazyLock;
4
5use regex::Regex;
6
7use crate::checksum::calculate_checksum;
8use crate::error::{Result, WaypointError};
9use crate::hooks;
10
11static VERSIONED_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^V([\d._]+)__(.+)$").unwrap());
12static REPEATABLE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^R__(.+)$").unwrap());
13
14#[derive(Debug, Clone, Eq, PartialEq)]
16pub struct MigrationVersion {
17 pub segments: Vec<u64>,
18 pub raw: String,
19}
20
21impl MigrationVersion {
22 pub fn parse(raw: &str) -> Result<Self> {
23 if raw.is_empty() {
24 return Err(WaypointError::MigrationParseError(
25 "Version string is empty".to_string(),
26 ));
27 }
28
29 let segments: std::result::Result<Vec<u64>, _> =
31 raw.split(['.', '_']).map(|s| s.parse::<u64>()).collect();
32
33 let segments = segments.map_err(|e| {
34 WaypointError::MigrationParseError(format!(
35 "Invalid version segment in '{}': {}",
36 raw, e
37 ))
38 })?;
39
40 Ok(MigrationVersion {
41 segments,
42 raw: raw.to_string(),
43 })
44 }
45}
46
47impl Ord for MigrationVersion {
48 fn cmp(&self, other: &Self) -> Ordering {
49 let max_len = self.segments.len().max(other.segments.len());
50 for i in 0..max_len {
51 let a = self.segments.get(i).copied().unwrap_or(0);
52 let b = other.segments.get(i).copied().unwrap_or(0);
53 match a.cmp(&b) {
54 Ordering::Equal => continue,
55 ord => return ord,
56 }
57 }
58 Ordering::Equal
59 }
60}
61
62impl PartialOrd for MigrationVersion {
63 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
64 Some(self.cmp(other))
65 }
66}
67
68impl fmt::Display for MigrationVersion {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 write!(f, "{}", self.raw)
71 }
72}
73
/// The type of a migration. Its `Display` form is the label `SQL` or `SQL_REPEATABLE`.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum MigrationType {
    /// A migration identified by a version number (`V{version}__{description}.sql`).
    Versioned,
    /// A repeatable migration (`R__{description}.sql`), which carries no version.
    Repeatable,
}

impl fmt::Display for MigrationType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Versioned => "SQL",
            Self::Repeatable => "SQL_REPEATABLE",
        };
        f.write_str(label)
    }
}
91
92#[derive(Debug, Clone)]
97pub enum MigrationKind {
98 Versioned(MigrationVersion),
99 Repeatable,
100}
101
102#[derive(Debug, Clone)]
104pub struct ResolvedMigration {
105 pub kind: MigrationKind,
106 pub description: String,
107 pub script: String,
108 pub checksum: i32,
109 pub sql: String,
110}
111
112impl ResolvedMigration {
113 pub fn version(&self) -> Option<&MigrationVersion> {
115 match &self.kind {
116 MigrationKind::Versioned(v) => Some(v),
117 MigrationKind::Repeatable => None,
118 }
119 }
120
121 pub fn migration_type(&self) -> MigrationType {
123 match &self.kind {
124 MigrationKind::Versioned(_) => MigrationType::Versioned,
125 MigrationKind::Repeatable => MigrationType::Repeatable,
126 }
127 }
128
129 pub fn is_versioned(&self) -> bool {
131 matches!(&self.kind, MigrationKind::Versioned(_))
132 }
133}
134
135pub fn parse_migration_filename(filename: &str) -> Result<(MigrationKind, String)> {
141 let stem = filename.strip_suffix(".sql").ok_or_else(|| {
143 WaypointError::MigrationParseError(format!(
144 "Migration file '{}' does not have .sql extension",
145 filename
146 ))
147 })?;
148
149 if let Some(caps) = VERSIONED_RE.captures(stem) {
150 let version_str = caps.get(1).unwrap().as_str();
151 let description = caps.get(2).unwrap().as_str().replace('_', " ");
152 let version = MigrationVersion::parse(version_str)?;
153 Ok((MigrationKind::Versioned(version), description))
154 } else if let Some(caps) = REPEATABLE_RE.captures(stem) {
155 let description = caps.get(1).unwrap().as_str().replace('_', " ");
156 Ok((MigrationKind::Repeatable, description))
157 } else {
158 Err(WaypointError::MigrationParseError(format!(
159 "Migration file '{}' does not match V{{version}}__{{description}}.sql or R__{{description}}.sql pattern",
160 filename
161 )))
162 }
163}
164
165pub fn scan_migrations(locations: &[std::path::PathBuf]) -> Result<Vec<ResolvedMigration>> {
167 let mut migrations = Vec::new();
168
169 for location in locations {
170 if !location.exists() {
171 tracing::warn!("Migration location does not exist: {}", location.display());
172 continue;
173 }
174
175 let entries = std::fs::read_dir(location).map_err(|e| {
176 WaypointError::IoError(std::io::Error::new(
177 e.kind(),
178 format!(
179 "Failed to read migration directory '{}': {}",
180 location.display(),
181 e
182 ),
183 ))
184 })?;
185
186 for entry in entries {
187 let entry = entry?;
188 let path = entry.path();
189
190 if !path.is_file() {
191 continue;
192 }
193
194 let filename = match path.file_name().and_then(|n| n.to_str()) {
195 Some(name) => name.to_string(),
196 None => continue,
197 };
198
199 if !filename.ends_with(".sql") {
201 continue;
202 }
203
204 if hooks::is_hook_file(&filename) {
206 continue;
207 }
208
209 if !filename.starts_with('V') && !filename.starts_with('R') {
211 continue;
212 }
213
214 let (kind, description) = parse_migration_filename(&filename)?;
215 let sql = std::fs::read_to_string(&path)?;
216 let checksum = calculate_checksum(&sql);
217
218 migrations.push(ResolvedMigration {
219 kind,
220 description,
221 script: filename,
222 checksum,
223 sql,
224 });
225 }
226 }
227
228 migrations.sort_by(|a, b| match (&a.kind, &b.kind) {
230 (MigrationKind::Versioned(va), MigrationKind::Versioned(vb)) => va.cmp(vb),
231 (MigrationKind::Versioned(_), MigrationKind::Repeatable) => Ordering::Less,
232 (MigrationKind::Repeatable, MigrationKind::Versioned(_)) => Ordering::Greater,
233 (MigrationKind::Repeatable, MigrationKind::Repeatable) => a.description.cmp(&b.description),
234 });
235
236 Ok(migrations)
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_parsing() {
        let v = MigrationVersion::parse("1").unwrap();
        assert_eq!(v.segments, vec![1]);

        let v = MigrationVersion::parse("1.2.3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);

        let v = MigrationVersion::parse("1_2_3").unwrap();
        assert_eq!(v.segments, vec![1, 2, 3]);
    }

    #[test]
    fn test_version_ordering() {
        let v1 = MigrationVersion::parse("1").unwrap();
        let v2 = MigrationVersion::parse("2").unwrap();
        let v10 = MigrationVersion::parse("10").unwrap();
        let v1_9 = MigrationVersion::parse("1.9").unwrap();
        let v1_10 = MigrationVersion::parse("1.10").unwrap();
        let v1_2 = MigrationVersion::parse("1.2").unwrap();
        let v1_2_0 = MigrationVersion::parse("1.2.0").unwrap();
        let v1_2_1 = MigrationVersion::parse("1.2.1").unwrap();

        assert!(v1 < v2);
        // Segments compare numerically, not lexically.
        assert!(v2 < v10);
        assert!(v1_9 < v1_10);
        assert!(v1_2 < v1_9);
        // Missing trailing segments count as zero, in both directions.
        assert_eq!(v1_2.cmp(&v1_2_0), Ordering::Equal);
        assert_eq!(v1_2_0.cmp(&v1_2), Ordering::Equal);
        assert!(v1_2_1 > v1_2);
    }

    #[test]
    fn test_version_parse_error() {
        assert!(MigrationVersion::parse("").is_err());
        assert!(MigrationVersion::parse("abc").is_err());
        assert!(MigrationVersion::parse("1..2").is_err());
        assert!(MigrationVersion::parse("1.").is_err());
        assert!(MigrationVersion::parse("-1").is_err());
    }

    #[test]
    fn test_version_display_preserves_raw() {
        let v = MigrationVersion::parse("1_2").unwrap();
        assert_eq!(v.to_string(), "1_2");
    }

    #[test]
    fn test_migration_type_display() {
        assert_eq!(MigrationType::Versioned.to_string(), "SQL");
        assert_eq!(MigrationType::Repeatable.to_string(), "SQL_REPEATABLE");
    }

    #[test]
    fn test_parse_versioned_filename() {
        let (kind, desc) = parse_migration_filename("V1__Create_users.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => assert_eq!(v.segments, vec![1]),
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Create users");
    }

    #[test]
    fn test_parse_versioned_dotted_version() {
        let (kind, desc) = parse_migration_filename("V1.2.3__Add_column.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => assert_eq!(v.segments, vec![1, 2, 3]),
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Add column");
    }

    #[test]
    fn test_parse_versioned_underscore_version() {
        let (kind, desc) = parse_migration_filename("V1_2__Add_index.sql").unwrap();
        match kind {
            MigrationKind::Versioned(v) => {
                assert_eq!(v.segments, vec![1, 2]);
                assert_eq!(v.raw, "1_2");
            }
            _ => panic!("Expected Versioned"),
        }
        assert_eq!(desc, "Add index");
    }

    #[test]
    fn test_parse_repeatable_filename() {
        let (kind, desc) = parse_migration_filename("R__Create_user_view.sql").unwrap();
        assert!(matches!(kind, MigrationKind::Repeatable));
        assert_eq!(desc, "Create user view");
    }

    #[test]
    fn test_parse_invalid_filename() {
        assert!(parse_migration_filename("random.sql").is_err());
        assert!(parse_migration_filename("V1_missing_separator.sql").is_err());
        assert!(parse_migration_filename("V1__no_ext").is_err());
        // An empty description or an empty version is rejected.
        assert!(parse_migration_filename("R__.sql").is_err());
        assert!(parse_migration_filename("V1__.sql").is_err());
        assert!(parse_migration_filename("V__desc.sql").is_err());
    }

    #[test]
    fn test_resolved_migration_accessors() {
        let versioned = ResolvedMigration {
            kind: MigrationKind::Versioned(MigrationVersion::parse("3").unwrap()),
            description: "Create users".to_string(),
            script: "V3__Create_users.sql".to_string(),
            checksum: 0,
            sql: String::new(),
        };
        assert!(versioned.is_versioned());
        assert_eq!(versioned.version().unwrap().segments, vec![3]);
        assert_eq!(versioned.migration_type(), MigrationType::Versioned);

        let repeatable = ResolvedMigration {
            kind: MigrationKind::Repeatable,
            description: "Create view".to_string(),
            script: "R__Create_view.sql".to_string(),
            checksum: 0,
            sql: String::new(),
        };
        assert!(!repeatable.is_versioned());
        assert!(repeatable.version().is_none());
        assert_eq!(repeatable.migration_type(), MigrationType::Repeatable);
    }
}