Skip to main content

shelly_data/
migration.rs

1use crate::{
2    adapter::AdapterKind,
3    error::{DataError, DataResult},
4};
5use serde::{Deserialize, Serialize};
6use std::{
7    fs,
8    path::{Path, PathBuf},
9    time::SystemTime,
10};
11
/// A single migration definition loaded from a `.sql` file on disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Identifier used to match this definition against recorded state.
    pub id: String,
    /// Human-readable name of the migration.
    pub name: String,
    /// SQL text of the "up" (apply) section.
    pub up_sql: String,
    /// SQL text of the "down" (rollback) section.
    pub down_sql: String,
    /// Location of the file this definition came from; used in error messages.
    pub path: PathBuf,
}
20
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub struct AppliedMigration {
23    pub id: String,
24    pub name: String,
25    pub applied_at_unix_ms: u64,
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, Default)]
29pub struct MigrationStatus {
30    pub applied: Vec<AppliedMigration>,
31    pub pending: Vec<Migration>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35struct MigrationState {
36    adapter: AdapterKind,
37    applied: Vec<AppliedMigration>,
38}
39
40pub struct MigrationEngine {
41    project_root: PathBuf,
42    adapter: AdapterKind,
43}
44
45impl MigrationEngine {
46    pub fn new(project_root: impl Into<PathBuf>, adapter: AdapterKind) -> DataResult<Self> {
47        if adapter == AdapterKind::None {
48            return Err(DataError::Migration(
49                "adapter is `none`; choose a SQL adapter before running migrations".to_string(),
50            ));
51        }
52        if !adapter.supports_migrations() {
53            return Err(DataError::Migration(format!(
54                "adapter `{}` does not support SQL migrations",
55                adapter.as_str()
56            )));
57        }
58        Ok(Self {
59            project_root: project_root.into(),
60            adapter,
61        })
62    }
63
64    pub fn status(&self, all_migrations: &[Migration]) -> DataResult<MigrationStatus> {
65        let state = self.read_state()?;
66        let pending = all_migrations
67            .iter()
68            .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
69            .cloned()
70            .collect::<Vec<_>>();
71        Ok(MigrationStatus {
72            applied: state.applied,
73            pending,
74        })
75    }
76
77    pub fn migrate(
78        &self,
79        all_migrations: &[Migration],
80        steps: Option<usize>,
81    ) -> DataResult<Vec<AppliedMigration>> {
82        let mut state = self.read_state()?;
83        let mut applied_now = Vec::<AppliedMigration>::new();
84
85        let pending = all_migrations
86            .iter()
87            .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
88            .cloned()
89            .collect::<Vec<_>>();
90        let pending = if let Some(steps) = steps {
91            pending.into_iter().take(steps).collect::<Vec<_>>()
92        } else {
93            pending
94        };
95
96        for migration in pending {
97            if migration.up_sql.trim().is_empty() {
98                return Err(DataError::Migration(format!(
99                    "migration {} has empty up SQL",
100                    migration.path.display()
101                )));
102            }
103            let applied = AppliedMigration {
104                id: migration.id,
105                name: migration.name,
106                applied_at_unix_ms: now_unix_ms(),
107            };
108            state.applied.push(applied.clone());
109            applied_now.push(applied);
110        }
111
112        self.write_state(&state)?;
113        Ok(applied_now)
114    }
115
116    pub fn rollback(
117        &self,
118        all_migrations: &[Migration],
119        steps: usize,
120    ) -> DataResult<Vec<AppliedMigration>> {
121        let mut state = self.read_state()?;
122        let mut rolled_back = Vec::<AppliedMigration>::new();
123        let steps = steps.max(1);
124
125        for _ in 0..steps {
126            let Some(last) = state.applied.pop() else {
127                break;
128            };
129            let Some(definition) = all_migrations
130                .iter()
131                .find(|migration| migration.id == last.id)
132            else {
133                return Err(DataError::Migration(format!(
134                    "cannot rollback migration `{}` because file is missing",
135                    last.id
136                )));
137            };
138            if definition.down_sql.trim().is_empty() {
139                return Err(DataError::Migration(format!(
140                    "migration {} has empty down SQL",
141                    definition.path.display()
142                )));
143            }
144            rolled_back.push(last);
145        }
146
147        self.write_state(&state)?;
148        Ok(rolled_back)
149    }
150
151    fn state_path(&self) -> PathBuf {
152        self.project_root
153            .join(".shelly")
154            .join("migrations")
155            .join(format!("{}.json", self.adapter.as_str()))
156    }
157
158    fn read_state(&self) -> DataResult<MigrationState> {
159        let state_path = self.state_path();
160        if !state_path.exists() {
161            return Ok(MigrationState {
162                adapter: self.adapter,
163                applied: Vec::new(),
164            });
165        }
166        let raw = fs::read_to_string(state_path)?;
167        let mut state: MigrationState = serde_json::from_str(&raw)?;
168        state.adapter = self.adapter;
169        Ok(state)
170    }
171
172    fn write_state(&self, state: &MigrationState) -> DataResult<()> {
173        let state_path = self.state_path();
174        if let Some(parent) = state_path.parent() {
175            fs::create_dir_all(parent)?;
176        }
177        let body = serde_json::to_string_pretty(state)?;
178        fs::write(state_path, format!("{body}\n"))?;
179        Ok(())
180    }
181}
182
183pub fn load_migrations(dir: &Path) -> DataResult<Vec<Migration>> {
184    if !dir.exists() {
185        return Ok(Vec::new());
186    }
187
188    let mut entries = fs::read_dir(dir)?
189        .filter_map(|entry| entry.ok())
190        .map(|entry| entry.path())
191        .filter(|path| path.extension().is_some_and(|extension| extension == "sql"))
192        .collect::<Vec<_>>();
193    entries.sort();
194
195    let mut migrations = Vec::with_capacity(entries.len());
196    for path in entries {
197        let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
198            continue;
199        };
200        let Some((id, name)) = parse_file_id_name(file_name) else {
201            continue;
202        };
203        let source = fs::read_to_string(&path)?;
204        let (up_sql, down_sql) = parse_up_down(&source, &path)?;
205        migrations.push(Migration {
206            id,
207            name,
208            up_sql,
209            down_sql,
210            path,
211        });
212    }
213    Ok(migrations)
214}
215
/// Splits a migration file name of the form `<id>_<name>.sql` into its id and
/// name parts.
///
/// The split happens at the first underscore, so the name may itself contain
/// underscores. Returns `None` when the `.sql` suffix or the underscore is
/// missing.
fn parse_file_id_name(file_name: &str) -> Option<(String, String)> {
    file_name
        .strip_suffix(".sql")
        .and_then(|stem| stem.split_once('_'))
        .map(|(id, name)| (id.to_owned(), name.to_owned()))
}
221
222fn parse_up_down(source: &str, path: &Path) -> DataResult<(String, String)> {
223    let up_marker = "-- +up";
224    let down_marker = "-- +down";
225    let Some(up_start) = source.find(up_marker) else {
226        return Err(DataError::Migration(format!(
227            "migration {} missing `-- +up` marker",
228            path.display()
229        )));
230    };
231    let Some(down_start) = source.find(down_marker) else {
232        return Err(DataError::Migration(format!(
233            "migration {} missing `-- +down` marker",
234            path.display()
235        )));
236    };
237    if down_start <= up_start {
238        return Err(DataError::Migration(format!(
239            "migration {} has invalid marker order",
240            path.display()
241        )));
242    }
243    let up_sql = source[up_start + up_marker.len()..down_start]
244        .trim()
245        .to_string();
246    let down_sql = source[down_start + down_marker.len()..].trim().to_string();
247    Ok((up_sql, down_sql))
248}
249
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock is set before the epoch, and saturates at
/// `u64::MAX` instead of silently truncating the 128-bit millisecond count.
fn now_unix_ms() -> u64 {
    let millis = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    u64::try_from(millis).unwrap_or(u64::MAX)
}
256
#[cfg(test)]
mod tests {
    use super::{load_migrations, MigrationEngine};
    use crate::AdapterKind;
    use std::{fs, path::PathBuf, time::SystemTime};

    /// Well-formed migration body shared by the tests that need a valid file.
    const CREATE_POSTS_SQL: &str = r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#;

    /// Builds a unique scratch path under the system temp directory.
    fn temp_path(prefix: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{nanos}"))
    }

    /// Creates a scratch project containing one migration file and returns
    /// `(project_root, migrations_dir)`.
    fn project_with_migration(prefix: &str, file_name: &str, body: &str) -> (PathBuf, PathBuf) {
        let root = temp_path(prefix);
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(migrations_dir.join(file_name), body).unwrap();
        (root, migrations_dir)
    }

    #[test]
    fn migration_lifecycle_applies_and_rolls_back() {
        let (root, migrations_dir) = project_with_migration(
            "shelly_data_migration",
            "20260505120000_create_posts.sql",
            CREATE_POSTS_SQL,
        );

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();

        // Applying moves the single migration from pending to applied.
        let applied = engine.migrate(&migrations, None).unwrap();
        assert_eq!(applied.len(), 1);
        let after_migrate = engine.status(&migrations).unwrap();
        assert_eq!(after_migrate.applied.len(), 1);
        assert_eq!(after_migrate.pending.len(), 0);

        // Rolling back empties the applied history again.
        let rolled_back = engine.rollback(&migrations, 1).unwrap();
        assert_eq!(rolled_back.len(), 1);
        let after_rollback = engine.status(&migrations).unwrap();
        assert_eq!(after_rollback.applied.len(), 0);

        fs::remove_dir_all(root).unwrap();
    }

    #[test]
    fn migration_loader_rejects_invalid_marker_order() {
        let (root, migrations_dir) = project_with_migration(
            "shelly_data_invalid_marker_order",
            "20260505130000_invalid.sql",
            r#"
-- +down
DROP TABLE posts;
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
"#,
        );

        let err = load_migrations(&migrations_dir).unwrap_err().to_string();
        assert!(err.contains("invalid marker order"));

        fs::remove_dir_all(root).unwrap();
    }

    #[test]
    fn rollback_fails_when_applied_migration_file_is_missing() {
        let file_name = "20260505140000_create_posts.sql";
        let (root, migrations_dir) = project_with_migration(
            "shelly_data_missing_migration_file",
            file_name,
            CREATE_POSTS_SQL,
        );

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();

        // Delete the file after it was recorded, then reload the definitions.
        fs::remove_file(migrations_dir.join(file_name)).unwrap();
        let now_missing = load_migrations(&migrations_dir).unwrap();
        let err = engine.rollback(&now_missing, 1).unwrap_err().to_string();
        assert!(err.contains("cannot rollback migration"));
        assert!(err.contains("file is missing"));

        fs::remove_dir_all(root).unwrap();
    }

    #[test]
    fn rollback_fails_when_down_sql_is_empty() {
        let (root, migrations_dir) = project_with_migration(
            "shelly_data_empty_down_sql",
            "20260505150000_create_posts.sql",
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
"#,
        );

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();
        let err = engine.rollback(&migrations, 1).unwrap_err().to_string();
        assert!(err.contains("empty down SQL"));

        fs::remove_dir_all(root).unwrap();
    }
}