1use crate::{
2 adapter::AdapterKind,
3 error::{DataError, DataResult},
4};
5use serde::{Deserialize, Serialize};
6use std::{
7 fs,
8 path::{Path, PathBuf},
9 time::SystemTime,
10};
11
/// A migration definition loaded from a `<id>_<name>.sql` file on disk.
///
/// Produced by [`load_migrations`]; the `up_sql`/`down_sql` sections come from
/// the `-- +up` / `-- +down` markers inside the file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    // Identifier parsed from the file-name prefix (text before the first `_`).
    pub id: String,
    // Human-readable name parsed from the file name (text after the first `_`).
    pub name: String,
    // SQL between the `-- +up` marker and the `-- +down` marker, trimmed.
    pub up_sql: String,
    // SQL after the `-- +down` marker, trimmed.
    pub down_sql: String,
    // Path of the source `.sql` file; used in error messages.
    pub path: PathBuf,
}
20
/// Record of a migration that has been applied, as persisted in the
/// per-adapter JSON state file under `.shelly/migrations/`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AppliedMigration {
    // Matches [`Migration::id`] of the applied migration.
    pub id: String,
    // Matches [`Migration::name`] of the applied migration.
    pub name: String,
    // Wall-clock application time, milliseconds since the Unix epoch.
    pub applied_at_unix_ms: u64,
}
27
/// Snapshot returned by [`MigrationEngine::status`]: which migrations are
/// already recorded as applied and which are still pending.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MigrationStatus {
    // Migrations recorded in the persisted state file.
    pub applied: Vec<AppliedMigration>,
    // Migrations present on disk but not yet recorded as applied.
    pub pending: Vec<Migration>,
}
33
/// On-disk JSON shape of the migration state file
/// (`.shelly/migrations/<adapter>.json`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MigrationState {
    // Adapter the state belongs to; normalized to the engine's adapter on read.
    adapter: AdapterKind,
    // Applied migrations in application order (rollback pops from the end).
    applied: Vec<AppliedMigration>,
}
39
/// Tracks which migrations have been applied for a given project/adapter pair
/// by reading and writing a JSON state file under the project root.
pub struct MigrationEngine {
    // Root directory containing the `.shelly/` state directory.
    project_root: PathBuf,
    // SQL adapter the state file is keyed by; validated in [`Self::new`].
    adapter: AdapterKind,
}
44
45impl MigrationEngine {
46 pub fn new(project_root: impl Into<PathBuf>, adapter: AdapterKind) -> DataResult<Self> {
47 if adapter == AdapterKind::None {
48 return Err(DataError::Migration(
49 "adapter is `none`; choose a SQL adapter before running migrations".to_string(),
50 ));
51 }
52 if !adapter.supports_migrations() {
53 return Err(DataError::Migration(format!(
54 "adapter `{}` does not support SQL migrations",
55 adapter.as_str()
56 )));
57 }
58 Ok(Self {
59 project_root: project_root.into(),
60 adapter,
61 })
62 }
63
64 pub fn status(&self, all_migrations: &[Migration]) -> DataResult<MigrationStatus> {
65 let state = self.read_state()?;
66 let pending = all_migrations
67 .iter()
68 .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
69 .cloned()
70 .collect::<Vec<_>>();
71 Ok(MigrationStatus {
72 applied: state.applied,
73 pending,
74 })
75 }
76
77 pub fn migrate(
78 &self,
79 all_migrations: &[Migration],
80 steps: Option<usize>,
81 ) -> DataResult<Vec<AppliedMigration>> {
82 let mut state = self.read_state()?;
83 let mut applied_now = Vec::<AppliedMigration>::new();
84
85 let pending = all_migrations
86 .iter()
87 .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
88 .cloned()
89 .collect::<Vec<_>>();
90 let pending = if let Some(steps) = steps {
91 pending.into_iter().take(steps).collect::<Vec<_>>()
92 } else {
93 pending
94 };
95
96 for migration in pending {
97 if migration.up_sql.trim().is_empty() {
98 return Err(DataError::Migration(format!(
99 "migration {} has empty up SQL",
100 migration.path.display()
101 )));
102 }
103 let applied = AppliedMigration {
104 id: migration.id,
105 name: migration.name,
106 applied_at_unix_ms: now_unix_ms(),
107 };
108 state.applied.push(applied.clone());
109 applied_now.push(applied);
110 }
111
112 self.write_state(&state)?;
113 Ok(applied_now)
114 }
115
116 pub fn rollback(
117 &self,
118 all_migrations: &[Migration],
119 steps: usize,
120 ) -> DataResult<Vec<AppliedMigration>> {
121 let mut state = self.read_state()?;
122 let mut rolled_back = Vec::<AppliedMigration>::new();
123 let steps = steps.max(1);
124
125 for _ in 0..steps {
126 let Some(last) = state.applied.pop() else {
127 break;
128 };
129 let Some(definition) = all_migrations
130 .iter()
131 .find(|migration| migration.id == last.id)
132 else {
133 return Err(DataError::Migration(format!(
134 "cannot rollback migration `{}` because file is missing",
135 last.id
136 )));
137 };
138 if definition.down_sql.trim().is_empty() {
139 return Err(DataError::Migration(format!(
140 "migration {} has empty down SQL",
141 definition.path.display()
142 )));
143 }
144 rolled_back.push(last);
145 }
146
147 self.write_state(&state)?;
148 Ok(rolled_back)
149 }
150
151 fn state_path(&self) -> PathBuf {
152 self.project_root
153 .join(".shelly")
154 .join("migrations")
155 .join(format!("{}.json", self.adapter.as_str()))
156 }
157
158 fn read_state(&self) -> DataResult<MigrationState> {
159 let state_path = self.state_path();
160 if !state_path.exists() {
161 return Ok(MigrationState {
162 adapter: self.adapter,
163 applied: Vec::new(),
164 });
165 }
166 let raw = fs::read_to_string(state_path)?;
167 let mut state: MigrationState = serde_json::from_str(&raw)?;
168 state.adapter = self.adapter;
169 Ok(state)
170 }
171
172 fn write_state(&self, state: &MigrationState) -> DataResult<()> {
173 let state_path = self.state_path();
174 if let Some(parent) = state_path.parent() {
175 fs::create_dir_all(parent)?;
176 }
177 let body = serde_json::to_string_pretty(state)?;
178 fs::write(state_path, format!("{body}\n"))?;
179 Ok(())
180 }
181}
182
183pub fn load_migrations(dir: &Path) -> DataResult<Vec<Migration>> {
184 if !dir.exists() {
185 return Ok(Vec::new());
186 }
187
188 let mut entries = fs::read_dir(dir)?
189 .filter_map(|entry| entry.ok())
190 .map(|entry| entry.path())
191 .filter(|path| path.extension().is_some_and(|extension| extension == "sql"))
192 .collect::<Vec<_>>();
193 entries.sort();
194
195 let mut migrations = Vec::with_capacity(entries.len());
196 for path in entries {
197 let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
198 continue;
199 };
200 let Some((id, name)) = parse_file_id_name(file_name) else {
201 continue;
202 };
203 let source = fs::read_to_string(&path)?;
204 let (up_sql, down_sql) = parse_up_down(&source, &path)?;
205 migrations.push(Migration {
206 id,
207 name,
208 up_sql,
209 down_sql,
210 path,
211 });
212 }
213 Ok(migrations)
214}
215
/// Parses a `<id>_<name>.sql` file name into `(id, name)`.
///
/// Splits at the first underscore, so `1_a_b.sql` yields `("1", "a_b")`.
/// Returns `None` for names that do not match the pattern — including names
/// whose id or name portion would be empty (e.g. `_posts.sql`, `123_.sql`) —
/// so such files are skipped by the loader. An empty id would otherwise alias
/// every other empty-id migration in applied/pending matching.
fn parse_file_id_name(file_name: &str) -> Option<(String, String)> {
    let stem = file_name.strip_suffix(".sql")?;
    let (id, name) = stem.split_once('_')?;
    if id.is_empty() || name.is_empty() {
        return None;
    }
    Some((id.to_string(), name.to_string()))
}
221
222fn parse_up_down(source: &str, path: &Path) -> DataResult<(String, String)> {
223 let up_marker = "-- +up";
224 let down_marker = "-- +down";
225 let Some(up_start) = source.find(up_marker) else {
226 return Err(DataError::Migration(format!(
227 "migration {} missing `-- +up` marker",
228 path.display()
229 )));
230 };
231 let Some(down_start) = source.find(down_marker) else {
232 return Err(DataError::Migration(format!(
233 "migration {} missing `-- +down` marker",
234 path.display()
235 )));
236 };
237 if down_start <= up_start {
238 return Err(DataError::Migration(format!(
239 "migration {} has invalid marker order",
240 path.display()
241 )));
242 }
243 let up_sql = source[up_start + up_marker.len()..down_start]
244 .trim()
245 .to_string();
246 let down_sql = source[down_start + down_marker.len()..].trim().to_string();
247 Ok((up_sql, down_sql))
248}
249
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Falls back to 0 when the system clock reads earlier than the epoch, rather
/// than panicking.
fn now_unix_ms() -> u64 {
    match SystemTime::UNIX_EPOCH.elapsed() {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
256
#[cfg(test)]
mod tests {
    use super::{load_migrations, MigrationEngine};
    use crate::AdapterKind;
    use std::{fs, path::PathBuf, time::SystemTime};

    // Happy path: load one migration from disk, apply it, check status, roll
    // it back, and confirm the persisted state reflects each step.
    #[test]
    fn migration_lifecycle_applies_and_rolls_back() {
        let root = temp_path("shelly_data_migration");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(
            migrations_dir.join("20260505120000_create_posts.sql"),
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#,
        )
        .unwrap();

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();

        // Applying with steps=None applies everything pending.
        let applied = engine.migrate(&migrations, None).unwrap();
        assert_eq!(applied.len(), 1);
        let status = engine.status(&migrations).unwrap();
        assert_eq!(status.applied.len(), 1);
        assert_eq!(status.pending.len(), 0);

        // Rolling back one step empties the applied list again.
        let rolled_back = engine.rollback(&migrations, 1).unwrap();
        assert_eq!(rolled_back.len(), 1);
        let status = engine.status(&migrations).unwrap();
        assert_eq!(status.applied.len(), 0);

        fs::remove_dir_all(root).unwrap();
    }

    // A `-- +down` marker appearing before `-- +up` must be rejected at load
    // time with a marker-order error.
    #[test]
    fn migration_loader_rejects_invalid_marker_order() {
        let root = temp_path("shelly_data_invalid_marker_order");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(
            migrations_dir.join("20260505130000_invalid.sql"),
            r#"
-- +down
DROP TABLE posts;
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
"#,
        )
        .unwrap();

        let err = load_migrations(&migrations_dir).unwrap_err().to_string();
        assert!(err.contains("invalid marker order"));

        fs::remove_dir_all(root).unwrap();
    }

    // Rollback needs the migration's file for its down SQL; deleting the file
    // after applying must produce a "file is missing" error.
    #[test]
    fn rollback_fails_when_applied_migration_file_is_missing() {
        let root = temp_path("shelly_data_missing_migration_file");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();

        let original_path = migrations_dir.join("20260505140000_create_posts.sql");
        fs::write(
            &original_path,
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#,
        )
        .unwrap();

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();

        // Delete the file, reload (now empty), and attempt a rollback.
        fs::remove_file(&original_path).unwrap();
        let now_missing = load_migrations(&migrations_dir).unwrap();
        let err = engine.rollback(&now_missing, 1).unwrap_err().to_string();
        assert!(err.contains("cannot rollback migration"));
        assert!(err.contains("file is missing"));

        fs::remove_dir_all(root).unwrap();
    }

    // A migration with nothing after `-- +down` can be applied (up SQL is
    // non-empty) but must refuse to roll back.
    #[test]
    fn rollback_fails_when_down_sql_is_empty() {
        let root = temp_path("shelly_data_empty_down_sql");
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(&migrations_dir).unwrap();
        fs::write(
            migrations_dir.join("20260505150000_create_posts.sql"),
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
"#,
        )
        .unwrap();

        let migrations = load_migrations(&migrations_dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();
        let err = engine.rollback(&migrations, 1).unwrap_err().to_string();
        assert!(err.contains("empty down SQL"));

        fs::remove_dir_all(root).unwrap();
    }

    // Fresh per-test directory under the OS temp dir; the nanosecond suffix
    // keeps concurrent test runs from colliding.
    fn temp_path(prefix: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{nanos}"))
    }
}