Skip to main content

schema_installer/
migration.rs

1use crate::error::SchemaInstallerError;
2use sha2::{Digest, Sha256};
3use std::path::PathBuf;
4
/// A single schema migration: its identity plus the SQL to run.
///
/// Values are handed out by a [`MigrationSource`]. When they come from a
/// [`DirectoryMigrationSource`], the fields are derived from a file named
/// `V{version}__{description}.sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Version identifier (for directory-based sources: the lower-cased text
    /// between the leading `V` and the first `__` of the file name).
    pub version: String,
    /// Human-readable description (for directory-based sources: the text
    /// after `__`, with underscores replaced by spaces).
    pub description: String,
    /// Where the script came from (for directory-based sources: the file
    /// path, lossily converted to UTF-8).
    pub script_path: String,
    /// Full SQL text of the migration script.
    pub sql: String,
}
12
/// Bookkeeping record describing a migration that has been run.
///
/// NOTE(review): nothing in this file constructs this type; the field
/// descriptions below follow the field names and should be confirmed against
/// the code that populates it.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Numeric identifier of the record.
    pub id: i64,
    /// Version of the migration the record refers to.
    pub version: String,
    /// Script location recorded for the migration.
    pub script_path: String,
    /// Checksum recorded for the script (presumably produced by
    /// `compute_checksum` -- confirm).
    pub checksum: String,
    /// Execution duration; milliseconds, going by the field name.
    pub execution_time_ms: i64,
    /// Installation timestamp kept as text; the format is not visible here.
    pub installed_at: String,
    /// Outcome as text (presumably one of the `MigrationStatus::as_str`
    /// values -- confirm).
    pub status: String,
    /// Version of the tool that recorded the entry.
    pub tool_version: String,
}
24
/// State of a migration, convertible to and from a lower-case text form.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MigrationStatus {
    /// Text form: `"success"`.
    Success,
    /// Text form: `"failed"`.
    Failed,
    /// Text form: `"pending"`.
    Pending,
}

impl MigrationStatus {
    /// Returns the lower-case text form of this status.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Success => "success",
            Self::Failed => "failed",
            Self::Pending => "pending",
        }
    }

    /// Parses the text form produced by [`MigrationStatus::as_str`].
    ///
    /// The comparison is exact (case-sensitive, no trimming); any other input
    /// yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        // Searching the variants through `as_str` keeps both directions of
        // the conversion defined by a single mapping.
        [Self::Success, Self::Failed, Self::Pending]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
    }
}
50
51pub trait MigrationSource: Send + Sync {
52    fn migrations(&self) -> Result<Vec<Migration>, SchemaInstallerError>;
53}
54
/// A [`MigrationSource`] that reads migration scripts from the `.sql` files
/// found directly inside a directory.
#[derive(Debug, Clone)]
pub struct DirectoryMigrationSource {
    /// Directory that is scanned for migration scripts.
    pub path: PathBuf,
}
58
59impl MigrationSource for DirectoryMigrationSource {
60    fn migrations(&self) -> Result<Vec<Migration>, SchemaInstallerError> {
61        let mut migrations = Vec::new();
62
63        if !self.path.exists() {
64            return Err(SchemaInstallerError::InvalidConfiguration(
65                format!("Migrations directory does not exist: {:?}", self.path),
66            ));
67        }
68
69        if !self.path.is_dir() {
70            return Err(SchemaInstallerError::InvalidConfiguration(
71                format!("Migrations path is not a directory: {:?}", self.path),
72            ));
73        }
74
75        let entries = std::fs::read_dir(&self.path)
76            .map_err(|e| SchemaInstallerError::Io(e))?;
77
78        for entry in entries {
79            let entry = entry.map_err(|e| SchemaInstallerError::Io(e))?;
80            let path = entry.path();
81
82            if !path.is_file() {
83                continue;
84            }
85
86            let filename = path
87                .file_name()
88                .and_then(|f| f.to_str())
89                .ok_or_else(|| {
90                    SchemaInstallerError::InvalidConfiguration(
91                        "Invalid filename encoding".to_string(),
92                    )
93                })?;
94
95            if !filename.to_lowercase().ends_with(".sql") {
96                continue;
97            }
98
99            let (version, description) = parse_migration_filename(filename)?;
100            let sql = std::fs::read_to_string(&path)
101                .map_err(|e| SchemaInstallerError::Io(e))?;
102
103            let script_path = path.to_string_lossy().to_string();
104
105            migrations.push(Migration {
106                version,
107                description,
108                script_path,
109                sql,
110            });
111        }
112
113        migrations.sort_by(|a, b| compare_versions(&a.version, &b.version));
114
115        Ok(migrations)
116    }
117}
118
119pub struct EmbeddedMigrationSource {
120    pub migrations: Vec<Migration>,
121}
122
123impl MigrationSource for EmbeddedMigrationSource {
124    fn migrations(&self) -> Result<Vec<Migration>, SchemaInstallerError> {
125        Ok(self.migrations.clone())
126    }
127}
128
129fn parse_migration_filename(filename: &str) -> Result<(String, String), SchemaInstallerError> {
130    let name_without_ext = filename
131        .strip_suffix(".sql")
132        .ok_or_else(|| {
133            SchemaInstallerError::InvalidConfiguration(
134                format!("File does not end with .sql: {}", filename),
135            )
136        })?;
137
138    let parts: Vec<&str> = name_without_ext.splitn(2, "__").collect();
139
140    if parts.len() != 2 {
141        return Err(SchemaInstallerError::InvalidConfiguration(
142            format!(
143                "Invalid migration filename format (expected V{{version}}__{{description}}.sql): {}",
144                filename
145            ),
146        ));
147    }
148
149    let version_part = parts[0].to_lowercase();
150    if !version_part.starts_with('v') {
151        return Err(SchemaInstallerError::InvalidConfiguration(
152            format!(
153                "Migration filename must start with V (case-insensitive): {}",
154                filename
155            ),
156        ));
157    }
158
159    let version = version_part[1..].to_string();
160    let description = parts[1].replace('_', " ");
161
162    if version.is_empty() {
163        return Err(SchemaInstallerError::InvalidConfiguration(
164            format!("Migration version cannot be empty: {}", filename),
165        ));
166    }
167
168    Ok((version, description))
169}
170
/// Orders two migration version strings segment by segment.
///
/// A version is split into segments on `.` and `_` (so `1_2` and `1.2` are
/// the same version); empty segments are ignored. Segments are compared
/// pairwise from the left and the first difference decides. If one version
/// is a prefix of the other, the one with more segments is greater
/// (`1` < `1.0`).
///
/// Within a pair, segments that both parse as `u64` compare numerically
/// (`10` > `2`). A numeric segment sorts before a non-numeric one, and two
/// non-numeric segments compare as plain strings, so distinct versions such
/// as `1.2a` and `1.2b` are never reported as equal merely because a segment
/// is not a number.
pub fn compare_versions(v1: &str, v2: &str) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let is_separator = |c: char| c == '.' || c == '_';
    let mut left = v1.split(is_separator).filter(|segment| !segment.is_empty());
    let mut right = v2.split(is_separator).filter(|segment| !segment.is_empty());

    loop {
        return match (left.next(), right.next()) {
            (Some(a), Some(b)) => match compare_version_segments(a, b) {
                Ordering::Equal => continue,
                decided => decided,
            },
            // All shared segments were equal: the longer version wins.
            (Some(_), None) => Ordering::Greater,
            (None, Some(_)) => Ordering::Less,
            (None, None) => Ordering::Equal,
        };
    }
}

/// Compares a single pair of version segments: numerically when both parse as
/// `u64`, numeric before non-numeric, otherwise as strings.
fn compare_version_segments(a: &str, b: &str) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (a.parse::<u64>(), b.parse::<u64>()) {
        (Ok(x), Ok(y)) => x.cmp(&y),
        (Ok(_), Err(_)) => Ordering::Less,
        (Err(_), Ok(_)) => Ordering::Greater,
        (Err(_), Err(_)) => a.cmp(b),
    }
}
189
190pub fn compute_checksum(sql: &str) -> String {
191    let normalized = sql.trim().replace("\r\n", "\n");
192    let mut hasher = Sha256::new();
193    hasher.update(normalized.as_bytes());
194    let result = hasher.finalize();
195    hex::encode(result)
196}
197
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// File names are split into a version and a space-separated description.
    #[test]
    fn test_parse_migration_filename() {
        let cases = [
            ("V1__create_users.sql", "1", "create users"),
            ("V1_2__add_email_column.sql", "1_2", "add email column"),
        ];

        for &(filename, expected_version, expected_description) in cases.iter() {
            let (version, description) = parse_migration_filename(filename).unwrap();
            assert_eq!(version, expected_version);
            assert_eq!(description, expected_description);
        }
    }

    /// A lower-case `v` prefix is accepted as well.
    #[test]
    fn test_parse_migration_filename_case_insensitive() {
        let parsed = parse_migration_filename("v1__create_users.sql").unwrap();
        assert_eq!(parsed, ("1".to_string(), "create users".to_string()));
    }

    /// Versions compare numerically per segment, not as text.
    #[test]
    fn test_version_comparison() {
        let cases = [
            ("1", "2", Ordering::Less),
            ("2", "1", Ordering::Greater),
            ("1", "1", Ordering::Equal),
            ("1.2", "1.3", Ordering::Less),
            ("1.10", "1.2", Ordering::Greater),
        ];

        for &(left, right, expected) in cases.iter() {
            assert_eq!(compare_versions(left, right), expected);
        }
    }

    /// Same input gives the same checksum; different input gives a different one.
    #[test]
    fn test_compute_checksum() {
        let users = "CREATE TABLE users (id BIGSERIAL PRIMARY KEY);";
        let posts = "CREATE TABLE posts (id BIGSERIAL PRIMARY KEY);";

        assert_eq!(compute_checksum(users), compute_checksum(users));
        assert_ne!(compute_checksum(users), compute_checksum(posts));
    }

    /// Trailing whitespace is ignored, interior whitespace is not.
    #[test]
    fn test_compute_checksum_normalizes_whitespace() {
        let plain = compute_checksum("CREATE TABLE users (id BIGSERIAL PRIMARY KEY);");
        let trailing_newline =
            compute_checksum("CREATE TABLE users (id BIGSERIAL PRIMARY KEY);\n");
        let reformatted =
            compute_checksum("CREATE TABLE users (\n  id BIGSERIAL PRIMARY KEY\n);");

        assert_eq!(plain, trailing_newline);
        assert_ne!(plain, reformatted);
    }
}