//! shelly_data/migration.rs — filesystem-backed migration bookkeeping.
use crate::{
    adapter::AdapterKind,
    error::{DataError, DataResult},
};
use serde::{Deserialize, Serialize};
use std::{
    fs,
    path::{Path, PathBuf},
    time::SystemTime,
};

/// A migration definition loaded from a `<id>_<name>.sql` file containing
/// `-- +up` and `-- +down` sections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Identifier parsed from the file name (text before the first `_`).
    pub id: String,
    /// Human-readable name parsed from the file name (text after the first `_`).
    pub name: String,
    /// Trimmed SQL between the `-- +up` and `-- +down` markers.
    pub up_sql: String,
    /// Trimmed SQL after the `-- +down` marker.
    pub down_sql: String,
    /// Path of the source `.sql` file this migration was loaded from.
    pub path: PathBuf,
}

/// Record of an applied migration, persisted in the JSON state file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AppliedMigration {
    /// Identifier matching [`Migration::id`].
    pub id: String,
    /// Name matching [`Migration::name`].
    pub name: String,
    /// Wall-clock timestamp (milliseconds since Unix epoch) when recorded.
    pub applied_at_unix_ms: u64,
}

/// Snapshot of where the project stands: which migrations are recorded as
/// applied and which known migrations are still pending.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MigrationStatus {
    /// Migrations recorded in the state file, in application order.
    pub applied: Vec<AppliedMigration>,
    /// Known migrations not yet applied, in the caller-supplied order.
    pub pending: Vec<Migration>,
}

/// On-disk JSON shape of the per-adapter state file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MigrationState {
    /// Adapter this state belongs to; overwritten from the engine on load.
    adapter: AdapterKind,
    /// Applied migrations in order (last element = most recently applied).
    applied: Vec<AppliedMigration>,
}

/// Tracks applied/pending migrations in a JSON state file at
/// `<project_root>/.shelly/migrations/<adapter>.json`.
///
/// NOTE(review): this engine only records bookkeeping — the up/down SQL is
/// never executed here; presumably a database layer runs it. TODO confirm.
pub struct MigrationEngine {
    /// Root directory under which the `.shelly` state directory lives.
    project_root: PathBuf,
    /// Adapter whose state file this engine reads and writes.
    adapter: AdapterKind,
}

45impl MigrationEngine {
46    pub fn new(project_root: impl Into<PathBuf>, adapter: AdapterKind) -> DataResult<Self> {
47        if adapter == AdapterKind::None {
48            return Err(DataError::Migration(
49                "adapter is `none`; choose postgres/mysql/sqlite before running migrations"
50                    .to_string(),
51            ));
52        }
53        Ok(Self {
54            project_root: project_root.into(),
55            adapter,
56        })
57    }
58
59    pub fn status(&self, all_migrations: &[Migration]) -> DataResult<MigrationStatus> {
60        let state = self.read_state()?;
61        let pending = all_migrations
62            .iter()
63            .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
64            .cloned()
65            .collect::<Vec<_>>();
66        Ok(MigrationStatus {
67            applied: state.applied,
68            pending,
69        })
70    }
71
72    pub fn migrate(
73        &self,
74        all_migrations: &[Migration],
75        steps: Option<usize>,
76    ) -> DataResult<Vec<AppliedMigration>> {
77        let mut state = self.read_state()?;
78        let mut applied_now = Vec::<AppliedMigration>::new();
79
80        let pending = all_migrations
81            .iter()
82            .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
83            .cloned()
84            .collect::<Vec<_>>();
85        let pending = if let Some(steps) = steps {
86            pending.into_iter().take(steps).collect::<Vec<_>>()
87        } else {
88            pending
89        };
90
91        for migration in pending {
92            if migration.up_sql.trim().is_empty() {
93                return Err(DataError::Migration(format!(
94                    "migration {} has empty up SQL",
95                    migration.path.display()
96                )));
97            }
98            let applied = AppliedMigration {
99                id: migration.id,
100                name: migration.name,
101                applied_at_unix_ms: now_unix_ms(),
102            };
103            state.applied.push(applied.clone());
104            applied_now.push(applied);
105        }
106
107        self.write_state(&state)?;
108        Ok(applied_now)
109    }
110
111    pub fn rollback(
112        &self,
113        all_migrations: &[Migration],
114        steps: usize,
115    ) -> DataResult<Vec<AppliedMigration>> {
116        let mut state = self.read_state()?;
117        let mut rolled_back = Vec::<AppliedMigration>::new();
118        let steps = steps.max(1);
119
120        for _ in 0..steps {
121            let Some(last) = state.applied.pop() else {
122                break;
123            };
124            let Some(definition) = all_migrations
125                .iter()
126                .find(|migration| migration.id == last.id)
127            else {
128                return Err(DataError::Migration(format!(
129                    "cannot rollback migration `{}` because file is missing",
130                    last.id
131                )));
132            };
133            if definition.down_sql.trim().is_empty() {
134                return Err(DataError::Migration(format!(
135                    "migration {} has empty down SQL",
136                    definition.path.display()
137                )));
138            }
139            rolled_back.push(last);
140        }
141
142        self.write_state(&state)?;
143        Ok(rolled_back)
144    }
145
146    fn state_path(&self) -> PathBuf {
147        self.project_root
148            .join(".shelly")
149            .join("migrations")
150            .join(format!("{}.json", self.adapter.as_str()))
151    }
152
153    fn read_state(&self) -> DataResult<MigrationState> {
154        let state_path = self.state_path();
155        if !state_path.exists() {
156            return Ok(MigrationState {
157                adapter: self.adapter,
158                applied: Vec::new(),
159            });
160        }
161        let raw = fs::read_to_string(state_path)?;
162        let mut state: MigrationState = serde_json::from_str(&raw)?;
163        state.adapter = self.adapter;
164        Ok(state)
165    }
166
167    fn write_state(&self, state: &MigrationState) -> DataResult<()> {
168        let state_path = self.state_path();
169        if let Some(parent) = state_path.parent() {
170            fs::create_dir_all(parent)?;
171        }
172        let body = serde_json::to_string_pretty(state)?;
173        fs::write(state_path, format!("{body}\n"))?;
174        Ok(())
175    }
176}
177
178pub fn load_migrations(dir: &Path) -> DataResult<Vec<Migration>> {
179    if !dir.exists() {
180        return Ok(Vec::new());
181    }
182
183    let mut entries = fs::read_dir(dir)?
184        .filter_map(|entry| entry.ok())
185        .map(|entry| entry.path())
186        .filter(|path| path.extension().is_some_and(|extension| extension == "sql"))
187        .collect::<Vec<_>>();
188    entries.sort();
189
190    let mut migrations = Vec::with_capacity(entries.len());
191    for path in entries {
192        let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
193            continue;
194        };
195        let Some((id, name)) = parse_file_id_name(file_name) else {
196            continue;
197        };
198        let source = fs::read_to_string(&path)?;
199        let (up_sql, down_sql) = parse_up_down(&source, &path)?;
200        migrations.push(Migration {
201            id,
202            name,
203            up_sql,
204            down_sql,
205            path,
206        });
207    }
208    Ok(migrations)
209}
210
/// Splits a `<id>_<name>.sql` file name into `(id, name)` at the first `_`.
///
/// Returns `None` when the `.sql` suffix or the underscore is missing; the
/// name part keeps any later underscores.
fn parse_file_id_name(file_name: &str) -> Option<(String, String)> {
    let stem = file_name.strip_suffix(".sql")?;
    stem.split_once('_')
        .map(|(id, name)| (id.to_owned(), name.to_owned()))
}

217fn parse_up_down(source: &str, path: &Path) -> DataResult<(String, String)> {
218    let up_marker = "-- +up";
219    let down_marker = "-- +down";
220    let Some(up_start) = source.find(up_marker) else {
221        return Err(DataError::Migration(format!(
222            "migration {} missing `-- +up` marker",
223            path.display()
224        )));
225    };
226    let Some(down_start) = source.find(down_marker) else {
227        return Err(DataError::Migration(format!(
228            "migration {} missing `-- +down` marker",
229            path.display()
230        )));
231    };
232    if down_start <= up_start {
233        return Err(DataError::Migration(format!(
234            "migration {} has invalid marker order",
235            path.display()
236        )));
237    }
238    let up_sql = source[up_start + up_marker.len()..down_start]
239        .trim()
240        .to_string();
241    let down_sql = source[down_start + down_marker.len()..].trim().to_string();
242    Ok((up_sql, down_sql))
243}
244
/// Milliseconds since the Unix epoch, saturating to 0 when the system clock
/// reports a time before the epoch.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}

#[cfg(test)]
mod tests {
    use super::{load_migrations, MigrationEngine};
    use crate::AdapterKind;
    use std::{fs, path::PathBuf, time::SystemTime};

    // End-to-end happy path: load one migration, apply it, inspect status,
    // then roll it back and check status again.
    #[test]
    fn migration_lifecycle_applies_and_rolls_back() {
        let root = temp_path("shelly_data_migration");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(
            migrations_dir.join("20260505120000_create_posts.sql"),
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#,
        )
        .unwrap();

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();

        let applied = engine.migrate(&migrations, None).unwrap();
        assert_eq!(applied.len(), 1);
        let status = engine.status(&migrations).unwrap();
        assert_eq!(status.applied.len(), 1);
        assert_eq!(status.pending.len(), 0);

        let rolled_back = engine.rollback(&migrations, 1).unwrap();
        assert_eq!(rolled_back.len(), 1);
        let status = engine.status(&migrations).unwrap();
        assert_eq!(status.applied.len(), 0);

        fs::remove_dir_all(root).unwrap();
    }

    // A file whose `-- +down` precedes `-- +up` must be rejected at load time.
    #[test]
    fn migration_loader_rejects_invalid_marker_order() {
        let root = temp_path("shelly_data_invalid_marker_order");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(
            migrations_dir.join("20260505130000_invalid.sql"),
            r#"
-- +down
DROP TABLE posts;
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
"#,
        )
        .unwrap();

        let err = load_migrations(&migrations_dir).unwrap_err().to_string();
        assert!(err.contains("invalid marker order"));

        fs::remove_dir_all(root).unwrap();
    }

    // Rollback needs the migration's source file for its down SQL; deleting
    // the file after applying must produce a descriptive error.
    #[test]
    fn rollback_fails_when_applied_migration_file_is_missing() {
        let root = temp_path("shelly_data_missing_migration_file");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();

        let original_path = migrations_dir.join("20260505140000_create_posts.sql");
        fs::write(
            &original_path,
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#,
        )
        .unwrap();

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();

        // Reload after deletion so the applied id has no matching definition.
        fs::remove_file(&original_path).unwrap();
        let now_missing = load_migrations(&migrations_dir).unwrap();
        let err = engine.rollback(&now_missing, 1).unwrap_err().to_string();
        assert!(err.contains("cannot rollback migration"));
        assert!(err.contains("file is missing"));

        fs::remove_dir_all(root).unwrap();
    }

    // An empty section after `-- +down` must block rollback with a clear error.
    #[test]
    fn rollback_fails_when_down_sql_is_empty() {
        let root = temp_path("shelly_data_empty_down_sql");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(
            migrations_dir.join("20260505150000_create_posts.sql"),
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
"#,
        )
        .unwrap();

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();
        let err = engine.rollback(&migrations, 1).unwrap_err().to_string();
        assert!(err.contains("empty down SQL"));

        fs::remove_dir_all(root).unwrap();
    }

    // Unique temp directory per test run, keyed by wall-clock nanoseconds so
    // concurrent test invocations do not collide.
    fn temp_path(prefix: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{nanos}"))
    }
}