1use crate::{
2 adapter::AdapterKind,
3 error::{DataError, DataResult},
4};
5use serde::{Deserialize, Serialize};
6use std::{
7 fs,
8 path::{Path, PathBuf},
9 time::SystemTime,
10};
11
/// One migration definition: its identifier, descriptive name, the SQL that applies it
/// (`up_sql`), the SQL that reverts it (`down_sql`), and the file it came from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Migration {
    /// Identifier; `load_migrations` takes it from the file name before the first `_`.
    pub id: String,
    /// Descriptive name; `load_migrations` takes it from the file name after the first `_`.
    pub name: String,
    /// SQL found between the `-- +up` and `-- +down` markers.
    pub up_sql: String,
    /// SQL found after the `-- +down` marker.
    pub down_sql: String,
    /// Path of the source `.sql` file, used in error messages.
    pub path: PathBuf,
}
20
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub struct AppliedMigration {
23 pub id: String,
24 pub name: String,
25 pub applied_at_unix_ms: u64,
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, Default)]
29pub struct MigrationStatus {
30 pub applied: Vec<AppliedMigration>,
31 pub pending: Vec<Migration>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35struct MigrationState {
36 adapter: AdapterKind,
37 applied: Vec<AppliedMigration>,
38}
39
40pub struct MigrationEngine {
41 project_root: PathBuf,
42 adapter: AdapterKind,
43}
44
45impl MigrationEngine {
46 pub fn new(project_root: impl Into<PathBuf>, adapter: AdapterKind) -> DataResult<Self> {
47 if adapter == AdapterKind::None {
48 return Err(DataError::Migration(
49 "adapter is `none`; choose postgres/mysql/sqlite before running migrations"
50 .to_string(),
51 ));
52 }
53 Ok(Self {
54 project_root: project_root.into(),
55 adapter,
56 })
57 }
58
59 pub fn status(&self, all_migrations: &[Migration]) -> DataResult<MigrationStatus> {
60 let state = self.read_state()?;
61 let pending = all_migrations
62 .iter()
63 .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
64 .cloned()
65 .collect::<Vec<_>>();
66 Ok(MigrationStatus {
67 applied: state.applied,
68 pending,
69 })
70 }
71
72 pub fn migrate(
73 &self,
74 all_migrations: &[Migration],
75 steps: Option<usize>,
76 ) -> DataResult<Vec<AppliedMigration>> {
77 let mut state = self.read_state()?;
78 let mut applied_now = Vec::<AppliedMigration>::new();
79
80 let pending = all_migrations
81 .iter()
82 .filter(|migration| state.applied.iter().all(|item| item.id != migration.id))
83 .cloned()
84 .collect::<Vec<_>>();
85 let pending = if let Some(steps) = steps {
86 pending.into_iter().take(steps).collect::<Vec<_>>()
87 } else {
88 pending
89 };
90
91 for migration in pending {
92 if migration.up_sql.trim().is_empty() {
93 return Err(DataError::Migration(format!(
94 "migration {} has empty up SQL",
95 migration.path.display()
96 )));
97 }
98 let applied = AppliedMigration {
99 id: migration.id,
100 name: migration.name,
101 applied_at_unix_ms: now_unix_ms(),
102 };
103 state.applied.push(applied.clone());
104 applied_now.push(applied);
105 }
106
107 self.write_state(&state)?;
108 Ok(applied_now)
109 }
110
111 pub fn rollback(
112 &self,
113 all_migrations: &[Migration],
114 steps: usize,
115 ) -> DataResult<Vec<AppliedMigration>> {
116 let mut state = self.read_state()?;
117 let mut rolled_back = Vec::<AppliedMigration>::new();
118 let steps = steps.max(1);
119
120 for _ in 0..steps {
121 let Some(last) = state.applied.pop() else {
122 break;
123 };
124 let Some(definition) = all_migrations
125 .iter()
126 .find(|migration| migration.id == last.id)
127 else {
128 return Err(DataError::Migration(format!(
129 "cannot rollback migration `{}` because file is missing",
130 last.id
131 )));
132 };
133 if definition.down_sql.trim().is_empty() {
134 return Err(DataError::Migration(format!(
135 "migration {} has empty down SQL",
136 definition.path.display()
137 )));
138 }
139 rolled_back.push(last);
140 }
141
142 self.write_state(&state)?;
143 Ok(rolled_back)
144 }
145
146 fn state_path(&self) -> PathBuf {
147 self.project_root
148 .join(".shelly")
149 .join("migrations")
150 .join(format!("{}.json", self.adapter.as_str()))
151 }
152
153 fn read_state(&self) -> DataResult<MigrationState> {
154 let state_path = self.state_path();
155 if !state_path.exists() {
156 return Ok(MigrationState {
157 adapter: self.adapter,
158 applied: Vec::new(),
159 });
160 }
161 let raw = fs::read_to_string(state_path)?;
162 let mut state: MigrationState = serde_json::from_str(&raw)?;
163 state.adapter = self.adapter;
164 Ok(state)
165 }
166
167 fn write_state(&self, state: &MigrationState) -> DataResult<()> {
168 let state_path = self.state_path();
169 if let Some(parent) = state_path.parent() {
170 fs::create_dir_all(parent)?;
171 }
172 let body = serde_json::to_string_pretty(state)?;
173 fs::write(state_path, format!("{body}\n"))?;
174 Ok(())
175 }
176}
177
178pub fn load_migrations(dir: &Path) -> DataResult<Vec<Migration>> {
179 if !dir.exists() {
180 return Ok(Vec::new());
181 }
182
183 let mut entries = fs::read_dir(dir)?
184 .filter_map(|entry| entry.ok())
185 .map(|entry| entry.path())
186 .filter(|path| path.extension().is_some_and(|extension| extension == "sql"))
187 .collect::<Vec<_>>();
188 entries.sort();
189
190 let mut migrations = Vec::with_capacity(entries.len());
191 for path in entries {
192 let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
193 continue;
194 };
195 let Some((id, name)) = parse_file_id_name(file_name) else {
196 continue;
197 };
198 let source = fs::read_to_string(&path)?;
199 let (up_sql, down_sql) = parse_up_down(&source, &path)?;
200 migrations.push(Migration {
201 id,
202 name,
203 up_sql,
204 down_sql,
205 path,
206 });
207 }
208 Ok(migrations)
209}
210
/// Splits a migration file name of the form `<id>_<name>.sql` into `(id, name)`.
///
/// The id is everything before the first `_`; the name is everything after it. Returns `None`
/// when the `.sql` suffix or the `_` separator is missing, or when the id would be empty
/// (e.g. `_draft.sql`). An empty id cannot identify a migration in the state file.
fn parse_file_id_name(file_name: &str) -> Option<(String, String)> {
    let stem = file_name.strip_suffix(".sql")?;
    let (id, name) = stem.split_once('_')?;
    if id.is_empty() {
        return None;
    }
    Some((id.to_owned(), name.to_owned()))
}
216
217fn parse_up_down(source: &str, path: &Path) -> DataResult<(String, String)> {
218 let up_marker = "-- +up";
219 let down_marker = "-- +down";
220 let Some(up_start) = source.find(up_marker) else {
221 return Err(DataError::Migration(format!(
222 "migration {} missing `-- +up` marker",
223 path.display()
224 )));
225 };
226 let Some(down_start) = source.find(down_marker) else {
227 return Err(DataError::Migration(format!(
228 "migration {} missing `-- +down` marker",
229 path.display()
230 )));
231 };
232 if down_start <= up_start {
233 return Err(DataError::Migration(format!(
234 "migration {} has invalid marker order",
235 path.display()
236 )));
237 }
238 let up_sql = source[up_start + up_marker.len()..down_start]
239 .trim()
240 .to_string();
241 let down_sql = source[down_start + down_marker.len()..].trim().to_string();
242 Ok((up_sql, down_sql))
243}
244
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0`. A millisecond count too large for `u64`
/// saturates to `u64::MAX` instead of being silently truncated by an `as` cast.
fn now_unix_ms() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
}
251
#[cfg(test)]
mod tests {
    use super::{load_migrations, MigrationEngine};
    use crate::AdapterKind;
    use std::{
        fs,
        path::{Path, PathBuf},
        time::SystemTime,
    };

    /// Well-formed migration shared by several tests.
    const CREATE_POSTS_SQL: &str = r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#;

    /// Scratch project root with a `migrations/` subdirectory. It is deleted on drop, so a
    /// failing assertion does not leave directories behind in the system temp dir.
    struct TempProject {
        root: PathBuf,
    }

    impl TempProject {
        fn new(prefix: &str) -> Self {
            let nanos = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            // The process id keeps concurrent test binaries from picking the same directory.
            let root = std::env::temp_dir()
                .join(format!("{prefix}_{}_{nanos}", std::process::id()));
            fs::create_dir_all(root.join("migrations")).unwrap();
            Self { root }
        }

        fn root(&self) -> &Path {
            &self.root
        }

        fn migrations_dir(&self) -> PathBuf {
            self.root.join("migrations")
        }

        /// Writes `contents` to `migrations/<file_name>` and returns the file's path.
        fn write_migration(&self, file_name: &str, contents: &str) -> PathBuf {
            let path = self.migrations_dir().join(file_name);
            fs::write(&path, contents).unwrap();
            path
        }
    }

    impl Drop for TempProject {
        fn drop(&mut self) {
            // Best effort: panicking here while already unwinding from a failed assert aborts.
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    #[test]
    fn migration_lifecycle_applies_and_rolls_back() {
        let project = TempProject::new("shelly_data_migration");
        project.write_migration("20260505120000_create_posts.sql", CREATE_POSTS_SQL);

        let migrations = load_migrations(&project.migrations_dir()).unwrap();
        let engine = MigrationEngine::new(project.root(), AdapterKind::Sqlite).unwrap();

        let applied = engine.migrate(&migrations, None).unwrap();
        assert_eq!(applied.len(), 1);
        let status = engine.status(&migrations).unwrap();
        assert_eq!(status.applied.len(), 1);
        assert_eq!(status.pending.len(), 0);

        let rolled_back = engine.rollback(&migrations, 1).unwrap();
        assert_eq!(rolled_back.len(), 1);
        let status = engine.status(&migrations).unwrap();
        assert_eq!(status.applied.len(), 0);
    }

    #[test]
    fn migration_loader_rejects_invalid_marker_order() {
        let project = TempProject::new("shelly_data_invalid_marker_order");
        project.write_migration(
            "20260505130000_invalid.sql",
            r#"
-- +down
DROP TABLE posts;
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
"#,
        );

        let err = load_migrations(&project.migrations_dir())
            .unwrap_err()
            .to_string();
        assert!(err.contains("invalid marker order"));
    }

    #[test]
    fn rollback_fails_when_applied_migration_file_is_missing() {
        let project = TempProject::new("shelly_data_missing_migration_file");
        let original_path =
            project.write_migration("20260505140000_create_posts.sql", CREATE_POSTS_SQL);

        let migrations = load_migrations(&project.migrations_dir()).unwrap();
        let engine = MigrationEngine::new(project.root(), AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();

        fs::remove_file(&original_path).unwrap();
        let now_missing = load_migrations(&project.migrations_dir()).unwrap();
        let err = engine.rollback(&now_missing, 1).unwrap_err().to_string();
        assert!(err.contains("cannot rollback migration"));
        assert!(err.contains("file is missing"));
    }

    #[test]
    fn rollback_fails_when_down_sql_is_empty() {
        let project = TempProject::new("shelly_data_empty_down_sql");
        project.write_migration(
            "20260505150000_create_posts.sql",
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
"#,
        );

        let migrations = load_migrations(&project.migrations_dir()).unwrap();
        let engine = MigrationEngine::new(project.root(), AdapterKind::Sqlite).unwrap();
        engine.migrate(&migrations, None).unwrap();
        let err = engine.rollback(&migrations, 1).unwrap_err().to_string();
        assert!(err.contains("empty down SQL"));
    }
}