Skip to main content

sql_orm_migrate/
filesystem.rs

1use crate::ModelSnapshot;
2use sql_orm_core::OrmError;
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7const MIGRATIONS_DIR: &str = "migrations";
8const ORM_VERSION: &str = env!("CARGO_PKG_VERSION");
9
/// Identifies a migration directory and gives access to the artifact files
/// stored inside it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MigrationScaffold {
    /// Migration identifier; also the name of the migration directory.
    pub id: String,
    /// Human-readable migration name as supplied by the caller.
    pub name: String,
    /// Directory that holds the migration's files.
    pub directory: PathBuf,
}

impl MigrationScaffold {
    /// Path of the forward migration script (`up.sql`) inside the directory.
    pub fn up_sql_path(&self) -> PathBuf {
        self.artifact("up.sql")
    }

    /// Path of the rollback script (`down.sql`) inside the directory.
    pub fn down_sql_path(&self) -> PathBuf {
        self.artifact("down.sql")
    }

    /// Path of the serialized model snapshot (`model_snapshot.json`).
    pub fn snapshot_path(&self) -> PathBuf {
        self.artifact("model_snapshot.json")
    }

    /// Appends `file_name` to the migration directory.
    fn artifact(&self, file_name: &str) -> PathBuf {
        let mut path = self.directory.clone();
        path.push(file_name);
        path
    }
}
30
/// A migration directory together with the paths of its three artifact files.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MigrationEntry {
    /// Migration identifier.
    pub id: String,
    /// Human-readable migration name.
    pub name: String,
    /// Directory that holds the migration's files.
    pub directory: PathBuf,
    /// Location of the forward migration script.
    pub up_sql_path: PathBuf,
    /// Location of the rollback script.
    pub down_sql_path: PathBuf,
    /// Location of the serialized model snapshot.
    pub snapshot_path: PathBuf,
}
40
41pub fn create_migration_scaffold(root: &Path, name: &str) -> Result<MigrationScaffold, OrmError> {
42    create_migration_scaffold_with_snapshot(root, name, &ModelSnapshot::default())
43}
44
45pub fn create_migration_scaffold_with_snapshot(
46    root: &Path,
47    name: &str,
48    snapshot: &ModelSnapshot,
49) -> Result<MigrationScaffold, OrmError> {
50    if name.trim().is_empty() {
51        return Err(OrmError::new("migration name cannot be empty"));
52    }
53
54    let slug = slugify(name);
55    let timestamp = migration_timestamp()?;
56    let id = format!("{timestamp}_{slug}");
57    let migrations_dir = root.join(MIGRATIONS_DIR);
58    let directory = migrations_dir.join(&id);
59
60    fs::create_dir_all(&directory)
61        .map_err(|_| OrmError::new("failed to create migration directory"))?;
62    fs::write(directory.join("up.sql"), initial_up_sql_template(&id))
63        .map_err(|_| OrmError::new("failed to write migration up.sql"))?;
64    fs::write(directory.join("down.sql"), initial_down_sql_template(&id))
65        .map_err(|_| OrmError::new("failed to write migration down.sql"))?;
66    write_model_snapshot(&directory.join("model_snapshot.json"), snapshot)?;
67
68    Ok(MigrationScaffold {
69        id,
70        name: name.to_string(),
71        directory,
72    })
73}
74
/// Builds the placeholder contents of a new `up.sql`: two SQL comment lines,
/// the first of which names the migration `id`.
fn initial_up_sql_template(id: &str) -> String {
    let mut template = String::from("-- Migration: ");
    template.push_str(id);
    template.push_str("\n-- SQL Server DDL for this migration.\n");
    template
}
78
/// Builds the placeholder contents of a new `down.sql`: three SQL comment
/// lines, the first of which names the migration `id`.
fn initial_down_sql_template(id: &str) -> String {
    [
        "-- Migration: ",
        id,
        "\n",
        "-- Manual rollback SQL for this editable migration.\n",
        "-- The current MVP does not execute down.sql automatically.\n",
    ]
    .concat()
}
84
85pub fn write_model_snapshot(path: &Path, snapshot: &ModelSnapshot) -> Result<(), OrmError> {
86    fs::write(path, snapshot.to_json_pretty()?)
87        .map_err(|_| OrmError::new("failed to write migration model snapshot"))
88}
89
90pub fn write_migration_up_sql(path: &Path, sql_statements: &[String]) -> Result<(), OrmError> {
91    let sql = if sql_statements.is_empty() {
92        String::from("-- No schema changes detected.\n")
93    } else {
94        let mut sql = sql_statements.join(";\n\n");
95        sql.push_str(";\n");
96        sql
97    };
98
99    fs::write(path, sql).map_err(|_| OrmError::new("failed to write migration up.sql"))
100}
101
102pub fn write_migration_down_sql(path: &Path, sql_statements: &[String]) -> Result<(), OrmError> {
103    let sql = if sql_statements.is_empty() {
104        String::from("-- No reversible schema changes detected.\n")
105    } else {
106        let mut sql = sql_statements.join(";\n\n");
107        sql.push_str(";\n");
108        sql
109    };
110
111    fs::write(path, sql).map_err(|_| OrmError::new("failed to write migration down.sql"))
112}
113
114pub fn read_model_snapshot(path: &Path) -> Result<ModelSnapshot, OrmError> {
115    let json = fs::read_to_string(path)
116        .map_err(|_| OrmError::new("failed to read migration model snapshot"))?;
117    ModelSnapshot::from_json(&json)
118}
119
120pub fn list_migrations(root: &Path) -> Result<Vec<MigrationEntry>, OrmError> {
121    let migrations_dir = root.join(MIGRATIONS_DIR);
122    if !migrations_dir.exists() {
123        return Ok(Vec::new());
124    }
125
126    let mut entries = fs::read_dir(&migrations_dir)
127        .map_err(|_| OrmError::new("failed to read migrations directory"))?
128        .filter_map(Result::ok)
129        .filter(|entry| entry.file_type().map(|kind| kind.is_dir()).unwrap_or(false))
130        .filter_map(|entry| parse_migration_entry(entry.path()))
131        .collect::<Vec<_>>();
132
133    entries.sort_by(|left, right| left.id.cmp(&right.id));
134    Ok(entries)
135}
136
137pub fn latest_migration(root: &Path) -> Result<Option<MigrationEntry>, OrmError> {
138    Ok(list_migrations(root)?.into_iter().last())
139}
140
141pub fn read_latest_model_snapshot(
142    root: &Path,
143) -> Result<Option<(MigrationEntry, ModelSnapshot)>, OrmError> {
144    let Some(migration) = latest_migration(root)? else {
145        return Ok(None);
146    };
147
148    let snapshot = read_model_snapshot(&migration.snapshot_path)?;
149    Ok(Some((migration, snapshot)))
150}
151
152pub fn build_database_update_script(
153    root: &Path,
154    history_table_sql: &str,
155) -> Result<String, OrmError> {
156    let migrations = list_migrations(root)?;
157    let mut script = vec![
158        "-- sql-orm database update".to_string(),
159        "SET ANSI_NULLS ON;".to_string(),
160        "SET ANSI_PADDING ON;".to_string(),
161        "SET ANSI_WARNINGS ON;".to_string(),
162        "SET ARITHABORT ON;".to_string(),
163        "SET CONCAT_NULL_YIELDS_NULL ON;".to_string(),
164        "SET QUOTED_IDENTIFIER ON;".to_string(),
165        "SET NUMERIC_ROUNDABORT OFF;".to_string(),
166        history_table_sql.to_string(),
167    ];
168
169    for migration in migrations {
170        let up_sql = fs::read_to_string(&migration.up_sql_path)
171            .map_err(|_| OrmError::new("failed to read migration up.sql"))?;
172        let checksum = checksum_hex(up_sql.as_bytes());
173        let statements = split_sql_statements(&up_sql);
174        let body = if statements.is_empty() {
175            String::new()
176        } else {
177            statements
178                .iter()
179                .map(|statement| format!("    EXEC(N'{}');", escape_sql_literal(statement)))
180                .collect::<Vec<_>>()
181                .join("\n")
182                + "\n"
183        };
184        script.push(render_idempotent_migration_block(
185            &migration.id,
186            &migration.name,
187            &checksum,
188            &body,
189        ));
190    }
191
192    Ok(script.join("\n\n"))
193}
194
195fn render_idempotent_migration_block(id: &str, name: &str, checksum: &str, body: &str) -> String {
196    format!(
197        "IF EXISTS (SELECT 1 FROM [dbo].[__sql_orm_migrations] WHERE [id] = N'{id}' AND [checksum] <> N'{checksum}')\nBEGIN\n    THROW 50001, N'sql-orm migration checksum mismatch for {id}', 1;\nEND\n\nIF NOT EXISTS (SELECT 1 FROM [dbo].[__sql_orm_migrations] WHERE [id] = N'{id}')\nBEGIN\n    BEGIN TRY\n        BEGIN TRANSACTION;\n{body}        INSERT INTO [dbo].[__sql_orm_migrations] ([id], [name], [checksum], [orm_version]) VALUES (N'{id}', N'{name}', N'{checksum}', N'{version}');\n        COMMIT TRANSACTION;\n    END TRY\n    BEGIN CATCH\n        IF XACT_STATE() <> 0\n            ROLLBACK TRANSACTION;\n        THROW;\n    END CATCH\nEND",
198        id = id,
199        name = name,
200        checksum = checksum,
201        version = ORM_VERSION,
202        body = body,
203    )
204}
205
206fn parse_migration_entry(path: PathBuf) -> Option<MigrationEntry> {
207    let file_name = path.file_name()?.to_str()?;
208    let (timestamp, slug) = file_name.split_once('_')?;
209    if timestamp.is_empty() || slug.is_empty() {
210        return None;
211    }
212
213    Some(MigrationEntry {
214        id: file_name.to_string(),
215        name: slug.replace('_', " "),
216        up_sql_path: path.join("up.sql"),
217        down_sql_path: path.join("down.sql"),
218        snapshot_path: path.join("model_snapshot.json"),
219        directory: path,
220    })
221}
222
223fn migration_timestamp() -> Result<String, OrmError> {
224    let duration = SystemTime::now()
225        .duration_since(UNIX_EPOCH)
226        .map_err(|_| OrmError::new("system clock is before UNIX_EPOCH"))?;
227    Ok(duration.as_nanos().to_string())
228}
229
/// Turns a free-form name into a lowercase, underscore-separated slug.
///
/// Runs of ASCII letters and digits are kept (lowercased). Every run of other
/// characters becomes a single `_`, and no `_` is emitted at either end. A
/// name without any ASCII alphanumeric yields an empty string.
fn slugify(name: &str) -> String {
    name.split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(|word| word.to_ascii_lowercase())
        .collect::<Vec<_>>()
        .join("_")
}
246
/// Computes the 64-bit FNV-1a hash of `bytes` and formats it as 16 lowercase
/// hexadecimal digits.
fn checksum_hex(bytes: &[u8]) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    // FNV-1a: XOR the byte in first, then multiply (wrapping modulo 2^64).
    let digest = bytes.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    format!("{digest:016x}")
}
256
/// Doubles every single quote in `sql` so the text can be embedded inside a
/// single-quoted SQL string literal.
fn escape_sql_literal(sql: &str) -> String {
    // Splitting on the quote and re-joining with two quotes replaces each
    // occurrence exactly once.
    sql.split('\'').collect::<Vec<_>>().join("''")
}
260
/// Splits a SQL script into individual statements at top-level semicolons.
///
/// A semicolon only ends a statement when it appears in plain code. Semicolons
/// inside `'string literals'`, `"quoted identifiers"`, `[bracketed
/// identifiers]`, `-- line comments` and (nestable) `/* block comments */`
/// are left alone. A doubled closing delimiter (`''`, `""`, `]]`) is treated
/// as an escaped delimiter.
///
/// Each returned statement is trimmed and terminated with a single `;`.
/// Pieces that contain nothing but whitespace and comments are dropped.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Lexical context of the scanner.
    #[derive(Clone, Copy)]
    enum Scan {
        /// Plain SQL text.
        Code,
        /// Inside a string or delimited identifier; holds the closing char.
        Delimited(char),
        /// Inside a `--` comment, until the end of the line.
        LineComment,
        /// Inside a `/* */` comment; holds the current nesting depth.
        BlockComment(usize),
    }

    let mut statements = Vec::new();
    let mut state = Scan::Code;
    // Byte offset at which the statement currently being scanned starts.
    let mut start = 0;
    // Whether the current statement has any non-whitespace, non-comment text.
    let mut has_code = false;
    let mut chars = sql.char_indices().peekable();

    while let Some((index, ch)) = chars.next() {
        let next = chars.peek().map(|&(_, following)| following);

        state = match state {
            Scan::Code => match ch {
                ';' => {
                    if has_code {
                        statements.push(format!("{};", sql[start..index].trim()));
                    }
                    // `;` is a single byte, so the next statement starts right after it.
                    start = index + 1;
                    has_code = false;
                    Scan::Code
                }
                '-' if next == Some('-') => {
                    chars.next();
                    Scan::LineComment
                }
                '/' if next == Some('*') => {
                    chars.next();
                    Scan::BlockComment(1)
                }
                _ => {
                    has_code |= !ch.is_whitespace();
                    match ch {
                        '\'' | '"' => Scan::Delimited(ch),
                        '[' => Scan::Delimited(']'),
                        _ => Scan::Code,
                    }
                }
            },
            Scan::Delimited(close) if ch == close => {
                if next == Some(close) {
                    // Doubled delimiter: an escaped literal character.
                    chars.next();
                    Scan::Delimited(close)
                } else {
                    Scan::Code
                }
            }
            Scan::LineComment if ch == '\n' => Scan::Code,
            Scan::BlockComment(depth) if ch == '*' && next == Some('/') => {
                chars.next();
                if depth > 1 {
                    Scan::BlockComment(depth - 1)
                } else {
                    Scan::Code
                }
            }
            Scan::BlockComment(depth) if ch == '/' && next == Some('*') => {
                chars.next();
                Scan::BlockComment(depth + 1)
            }
            unchanged => unchanged,
        };
    }

    // Trailing statement without a terminating semicolon.
    if has_code {
        statements.push(format!("{};", sql[start..].trim()));
    }

    statements
}
274
#[cfg(test)]
mod tests {
    use super::{
        build_database_update_script, create_migration_scaffold,
        create_migration_scaffold_with_snapshot, latest_migration, list_migrations,
        read_latest_model_snapshot, read_model_snapshot, write_migration_down_sql,
        write_migration_up_sql, write_model_snapshot,
    };
    use crate::{ModelSnapshot, SchemaSnapshot};
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Creates a fresh, uniquely named project root in the system temp dir.
    ///
    /// Tests run in parallel, and a timestamp alone can repeat on clocks with
    /// coarse resolution. The process id and a per-process sequence number
    /// keep two tests from ever sharing a root.
    fn temp_project_root() -> PathBuf {
        static NEXT_ROOT: AtomicU64 = AtomicU64::new(0);

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let sequence = NEXT_ROOT.fetch_add(1, Ordering::Relaxed);
        let process = std::process::id();
        let path = std::env::temp_dir()
            .join(format!("sql_orm_migrate_{process}_{nanos}_{sequence}"));
        fs::create_dir_all(&path).unwrap();
        path
    }

    #[test]
    fn creates_scaffolded_migration_files() {
        let root = temp_project_root();

        let scaffold = create_migration_scaffold(&root, "Create Customers").unwrap();

        assert!(scaffold.id.contains("create_customers"));
        assert!(scaffold.up_sql_path().exists());
        assert!(scaffold.down_sql_path().exists());
        assert!(scaffold.snapshot_path().exists());
        assert!(!scaffold.directory.join("migration.rs").exists());

        assert_eq!(
            fs::read_to_string(scaffold.up_sql_path()).unwrap(),
            format!(
                "-- Migration: {}\n-- SQL Server DDL for this migration.\n",
                scaffold.id
            )
        );
        assert_eq!(
            fs::read_to_string(scaffold.down_sql_path()).unwrap(),
            format!(
                "-- Migration: {}\n-- Manual rollback SQL for this editable migration.\n-- The current MVP does not execute down.sql automatically.\n",
                scaffold.id
            )
        );

        let snapshot = read_model_snapshot(&scaffold.snapshot_path()).unwrap();
        assert_eq!(snapshot, ModelSnapshot::default());
    }

    #[test]
    fn writes_and_reads_model_snapshot_artifact() {
        let root = temp_project_root();
        let snapshot_path = root.join("model_snapshot.json");
        let snapshot = ModelSnapshot::new(vec![SchemaSnapshot::new("sales", Vec::new())]);

        write_model_snapshot(&snapshot_path, &snapshot).unwrap();

        assert_eq!(read_model_snapshot(&snapshot_path).unwrap(), snapshot);
    }

    #[test]
    fn writes_generated_down_sql_artifact() {
        let root = temp_project_root();
        let down_sql_path = root.join("down.sql");

        write_migration_down_sql(
            &down_sql_path,
            &[
                "DROP TABLE [sales].[customers]".to_string(),
                "DROP SCHEMA [sales]".to_string(),
            ],
        )
        .unwrap();

        assert_eq!(
            fs::read_to_string(down_sql_path).unwrap(),
            "DROP TABLE [sales].[customers];\n\nDROP SCHEMA [sales];\n"
        );
    }

    #[test]
    fn creates_scaffold_with_provided_model_snapshot() {
        let root = temp_project_root();
        let snapshot = ModelSnapshot::new(vec![SchemaSnapshot::new("sales", Vec::new())]);

        let scaffold =
            create_migration_scaffold_with_snapshot(&root, "Create Sales", &snapshot).unwrap();

        assert_eq!(
            read_model_snapshot(&scaffold.snapshot_path()).unwrap(),
            snapshot
        );
    }

    #[test]
    fn lists_migrations_in_sorted_order() {
        let root = temp_project_root();
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(migrations_dir.join("200_create_orders")).unwrap();
        fs::create_dir_all(migrations_dir.join("100_create_customers")).unwrap();

        let migrations = list_migrations(&root).unwrap();

        assert_eq!(migrations.len(), 2);
        assert_eq!(migrations[0].id, "100_create_customers");
        assert_eq!(migrations[1].id, "200_create_orders");
    }

    #[test]
    fn returns_latest_migration_in_lexical_order() {
        let root = temp_project_root();
        let migrations_dir = root.join("migrations");
        fs::create_dir_all(migrations_dir.join("100_create_customers")).unwrap();
        fs::create_dir_all(migrations_dir.join("200_create_orders")).unwrap();

        let latest = latest_migration(&root).unwrap().unwrap();

        assert_eq!(latest.id, "200_create_orders");
    }

    #[test]
    fn reads_latest_model_snapshot_from_last_local_migration() {
        let root = temp_project_root();
        let older_dir = root.join("migrations/100_create_customers");
        let newer_dir = root.join("migrations/200_create_orders");
        fs::create_dir_all(&older_dir).unwrap();
        fs::create_dir_all(&newer_dir).unwrap();
        fs::write(older_dir.join("up.sql"), "-- noop").unwrap();
        fs::write(older_dir.join("down.sql"), "-- noop").unwrap();
        fs::write(
            older_dir.join("model_snapshot.json"),
            "{\n  \"schemas\": []\n}\n",
        )
        .unwrap();
        fs::write(newer_dir.join("up.sql"), "-- noop").unwrap();
        fs::write(newer_dir.join("down.sql"), "-- noop").unwrap();
        fs::write(
            newer_dir.join("model_snapshot.json"),
            "{\n  \"schemas\": [\n    {\n      \"name\": \"sales\",\n      \"tables\": []\n    }\n  ]\n}\n",
        )
        .unwrap();

        let (migration, snapshot) = read_latest_model_snapshot(&root).unwrap().unwrap();

        assert_eq!(migration.id, "200_create_orders");
        assert!(snapshot.schema("sales").is_some());
    }

    #[test]
    fn builds_database_update_script_with_history_inserts() {
        let root = temp_project_root();
        let scaffold = create_migration_scaffold(&root, "Create Customers").unwrap();
        fs::write(
            scaffold.directory.join("up.sql"),
            "CREATE SCHEMA [sales];\nCREATE TABLE [sales].[customers] ([id] bigint NOT NULL);",
        )
        .unwrap();

        let script =
            build_database_update_script(&root, "CREATE TABLE [dbo].[__sql_orm_migrations] (...);")
                .unwrap();

        assert!(script.contains("CREATE TABLE [dbo].[__sql_orm_migrations]"));
        assert!(script.contains("SET ANSI_NULLS ON;"));
        assert!(script.contains("SET QUOTED_IDENTIFIER ON;"));
        assert!(script.contains("SET NUMERIC_ROUNDABORT OFF;"));
        assert!(script.contains("IF NOT EXISTS (SELECT 1 FROM [dbo].[__sql_orm_migrations]"));
        assert!(script.contains("IF EXISTS (SELECT 1 FROM [dbo].[__sql_orm_migrations]"));
        assert!(script.contains("THROW 50001, N'sql-orm migration checksum mismatch"));
        assert!(script.contains("BEGIN TRY"));
        assert!(script.contains("BEGIN TRANSACTION;"));
        assert!(script.contains("EXEC(N'CREATE SCHEMA [sales];');"));
        assert!(
            script.contains("EXEC(N'CREATE TABLE [sales].[customers] ([id] bigint NOT NULL);');")
        );
        assert!(script.contains("INSERT INTO [dbo].[__sql_orm_migrations]"));
        assert!(script.contains("COMMIT TRANSACTION;"));
        assert!(script.contains("ROLLBACK TRANSACTION;"));
    }

    #[test]
    fn builds_database_update_script_without_empty_exec_blocks() {
        let root = temp_project_root();
        let scaffold = create_migration_scaffold(&root, "Noop").unwrap();
        fs::write(
            scaffold.directory.join("up.sql"),
            "-- comment only migration\n\n-- still intentionally empty\n",
        )
        .unwrap();

        let script =
            build_database_update_script(&root, "CREATE TABLE [dbo].[__sql_orm_migrations] (...);")
                .unwrap();

        assert!(!script.contains("EXEC(N'');"));
        assert!(script.contains("INSERT INTO [dbo].[__sql_orm_migrations]"));
    }

    #[test]
    fn writes_up_sql_from_compiled_statements() {
        let root = temp_project_root();
        let up_sql_path = root.join("up.sql");

        write_migration_up_sql(
            &up_sql_path,
            &[
                "CREATE SCHEMA [sales]".to_string(),
                "CREATE TABLE [sales].[customers] ([id] bigint NOT NULL)".to_string(),
            ],
        )
        .unwrap();

        let sql = fs::read_to_string(up_sql_path).unwrap();

        assert_eq!(
            sql,
            "CREATE SCHEMA [sales];\n\nCREATE TABLE [sales].[customers] ([id] bigint NOT NULL);\n"
        );
    }

    #[test]
    fn writes_noop_up_sql_when_no_statements_exist() {
        let root = temp_project_root();
        let up_sql_path = root.join("up.sql");

        write_migration_up_sql(&up_sql_path, &[]).unwrap();

        assert_eq!(
            fs::read_to_string(up_sql_path).unwrap(),
            "-- No schema changes detected.\n"
        );
    }
}