schema_installer/
migration.rs1use crate::error::SchemaInstallerError;
2use sha2::{Digest, Sha256};
3use std::path::PathBuf;
4
/// A single migration script together with the metadata describing it.
///
/// `DirectoryMigrationSource` builds these from files named
/// `V{version}__{description}.sql`; other sources may construct them directly.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Version identifier, e.g. `"1"` or `"1.2"`; used to order migrations.
    pub version: String,
    /// Human-readable description of what the migration does.
    pub description: String,
    /// Location of the script the SQL was loaded from.
    pub script_path: String,
    /// Full SQL text of the migration.
    pub sql: String,
}
12
/// Record describing one migration that has already been run.
///
/// NOTE(review): the code that fills these records is not in this file, so the
/// field notes below marked "presumably" should be confirmed against it.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Numeric identifier of the record.
    pub id: i64,
    /// Version of the migration this record refers to.
    pub version: String,
    /// Path of the script that was run.
    pub script_path: String,
    /// Checksum of the script; presumably produced by `compute_checksum` (TODO confirm).
    pub checksum: String,
    /// How long the migration took to execute, in milliseconds.
    pub execution_time_ms: i64,
    /// When the migration was installed; the timestamp format is not defined here.
    pub installed_at: String,
    /// Status label; presumably one of the strings from `MigrationStatus::as_str` (TODO confirm).
    pub status: String,
    /// Version of the tool that applied the migration.
    pub tool_version: String,
}
24
/// Status of a migration, convertible to and from a lowercase text label.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MigrationStatus {
    Success,
    Failed,
    Pending,
}

impl MigrationStatus {
    /// Returns the label for this status: `"success"`, `"failed"` or `"pending"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Success => "success",
            Self::Failed => "failed",
            Self::Pending => "pending",
        }
    }

    /// Looks up the status whose label is exactly `s`.
    ///
    /// The comparison is case-sensitive and does not trim whitespace; any other
    /// input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        [Self::Success, Self::Failed, Self::Pending]
            .iter()
            .copied()
            .find(|status| status.as_str() == s)
    }
}
50
51pub trait MigrationSource: Send + Sync {
52 fn migrations(&self) -> Result<Vec<Migration>, SchemaInstallerError>;
53}
54
55pub struct DirectoryMigrationSource {
56 pub path: PathBuf,
57}
58
59impl MigrationSource for DirectoryMigrationSource {
60 fn migrations(&self) -> Result<Vec<Migration>, SchemaInstallerError> {
61 let mut migrations = Vec::new();
62
63 if !self.path.exists() {
64 return Err(SchemaInstallerError::InvalidConfiguration(
65 format!("Migrations directory does not exist: {:?}", self.path),
66 ));
67 }
68
69 if !self.path.is_dir() {
70 return Err(SchemaInstallerError::InvalidConfiguration(
71 format!("Migrations path is not a directory: {:?}", self.path),
72 ));
73 }
74
75 let entries = std::fs::read_dir(&self.path)
76 .map_err(|e| SchemaInstallerError::Io(e))?;
77
78 for entry in entries {
79 let entry = entry.map_err(|e| SchemaInstallerError::Io(e))?;
80 let path = entry.path();
81
82 if !path.is_file() {
83 continue;
84 }
85
86 let filename = path
87 .file_name()
88 .and_then(|f| f.to_str())
89 .ok_or_else(|| {
90 SchemaInstallerError::InvalidConfiguration(
91 "Invalid filename encoding".to_string(),
92 )
93 })?;
94
95 if !filename.to_lowercase().ends_with(".sql") {
96 continue;
97 }
98
99 let (version, description) = parse_migration_filename(filename)?;
100 let sql = std::fs::read_to_string(&path)
101 .map_err(|e| SchemaInstallerError::Io(e))?;
102
103 let script_path = path.to_string_lossy().to_string();
104
105 migrations.push(Migration {
106 version,
107 description,
108 script_path,
109 sql,
110 });
111 }
112
113 migrations.sort_by(|a, b| compare_versions(&a.version, &b.version));
114
115 Ok(migrations)
116 }
117}
118
119pub struct EmbeddedMigrationSource {
120 pub migrations: Vec<Migration>,
121}
122
123impl MigrationSource for EmbeddedMigrationSource {
124 fn migrations(&self) -> Result<Vec<Migration>, SchemaInstallerError> {
125 Ok(self.migrations.clone())
126 }
127}
128
129fn parse_migration_filename(filename: &str) -> Result<(String, String), SchemaInstallerError> {
130 let name_without_ext = filename
131 .strip_suffix(".sql")
132 .ok_or_else(|| {
133 SchemaInstallerError::InvalidConfiguration(
134 format!("File does not end with .sql: {}", filename),
135 )
136 })?;
137
138 let parts: Vec<&str> = name_without_ext.splitn(2, "__").collect();
139
140 if parts.len() != 2 {
141 return Err(SchemaInstallerError::InvalidConfiguration(
142 format!(
143 "Invalid migration filename format (expected V{{version}}__{{description}}.sql): {}",
144 filename
145 ),
146 ));
147 }
148
149 let version_part = parts[0].to_lowercase();
150 if !version_part.starts_with('v') {
151 return Err(SchemaInstallerError::InvalidConfiguration(
152 format!(
153 "Migration filename must start with V (case-insensitive): {}",
154 filename
155 ),
156 ));
157 }
158
159 let version = version_part[1..].to_string();
160 let description = parts[1].replace('_', " ");
161
162 if version.is_empty() {
163 return Err(SchemaInstallerError::InvalidConfiguration(
164 format!("Migration version cannot be empty: {}", filename),
165 ));
166 }
167
168 Ok((version, description))
169}
170
/// Orders two migration version strings numerically, segment by segment.
///
/// A version is split on `.` and `_` (so `1.2` and `1_2`, the form that
/// appears in `V1_2__name.sql` file names, are equivalent) and each segment is
/// compared as an unsigned integer, which makes `1.10` sort after `1.2`. When
/// one version is a prefix of the other, the shorter one sorts first.
///
/// Segments that are not valid `u64` numbers are ignored, so versions that
/// differ only in such segments compare as equal.
pub fn compare_versions(v1: &str, v2: &str) -> std::cmp::Ordering {
    /// Extracts the numeric segments of a version string.
    fn numeric_segments(version: &str) -> Vec<u64> {
        version
            .split(|c: char| c == '.' || c == '_')
            .filter_map(|segment| segment.parse::<u64>().ok())
            .collect()
    }

    // `Vec` ordering is lexicographic: element by element, then by length.
    numeric_segments(v1).cmp(&numeric_segments(v2))
}
189
190pub fn compute_checksum(sql: &str) -> String {
191 let normalized = sql.trim().replace("\r\n", "\n");
192 let mut hasher = Sha256::new();
193 hasher.update(normalized.as_bytes());
194 let result = hasher.finalize();
195 hex::encode(result)
196}
197
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    const USERS_DDL: &str = "CREATE TABLE users (id BIGSERIAL PRIMARY KEY);";

    /// Well-formed names yield the version and a space-separated description.
    #[test]
    fn test_parse_migration_filename() {
        let parsed = parse_migration_filename("V1__create_users.sql").unwrap();
        assert_eq!(parsed, ("1".to_string(), "create users".to_string()));

        let parsed = parse_migration_filename("V1_2__add_email_column.sql").unwrap();
        assert_eq!(parsed, ("1_2".to_string(), "add email column".to_string()));
    }

    /// A lowercase `v` prefix is accepted as well.
    #[test]
    fn test_parse_migration_filename_case_insensitive() {
        let parsed = parse_migration_filename("v1__create_users.sql").unwrap();
        assert_eq!(parsed, ("1".to_string(), "create users".to_string()));
    }

    /// Versions are ordered numerically per segment, not as plain strings.
    #[test]
    fn test_version_comparison() {
        let cases = [
            ("1", "2", Ordering::Less),
            ("2", "1", Ordering::Greater),
            ("1", "1", Ordering::Equal),
            ("1.2", "1.3", Ordering::Less),
            ("1.10", "1.2", Ordering::Greater),
        ];

        for &(left, right, expected) in cases.iter() {
            assert_eq!(
                compare_versions(left, right),
                expected,
                "comparing {} with {}",
                left,
                right
            );
        }
    }

    /// The checksum is stable for equal input and differs for different input.
    #[test]
    fn test_compute_checksum() {
        let first = compute_checksum(USERS_DDL);
        let second = compute_checksum(USERS_DDL);
        assert_eq!(first, second);

        let other = compute_checksum("CREATE TABLE posts (id BIGSERIAL PRIMARY KEY);");
        assert_ne!(first, other);
    }

    /// Surrounding whitespace is ignored; whitespace inside the script is not.
    #[test]
    fn test_compute_checksum_normalizes_whitespace() {
        let plain = compute_checksum(USERS_DDL);
        let trailing_newline =
            compute_checksum("CREATE TABLE users (id BIGSERIAL PRIMARY KEY);\n");
        let reformatted =
            compute_checksum("CREATE TABLE users (\n id BIGSERIAL PRIMARY KEY\n);");

        assert_eq!(plain, trailing_newline);
        assert_ne!(plain, reformatted);
    }
}