// soar_core/database/migration.rs

use include_dir::Dir;
use rusqlite::Connection;

use crate::{constants::NESTS_MIGRATIONS_DIR, error::SoarError, SoarResult};

/// A single schema migration parsed from a `V<version>_<description>.sql` file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Position in the migration sequence. Migrations are applied in ascending
    /// order, and the database's `user_version` is set to this value once the
    /// migration commits.
    version: i32,
    /// The SQL script executed for this migration.
    sql: String,
}

11pub struct MigrationManager {
12    conn: Connection,
13}
14
15impl MigrationManager {
16    pub fn new(conn: Connection) -> rusqlite::Result<Self> {
17        Ok(Self { conn })
18    }
19
20    fn get_current_version(&self) -> rusqlite::Result<i32> {
21        self.conn
22            .query_row("PRAGMA user_version", [], |row| row.get(0))
23    }
24
25    fn run_migration(&mut self, migration: &Migration) -> rusqlite::Result<()> {
26        let tx = self.conn.transaction()?;
27
28        match tx.execute_batch(&migration.sql) {
29            Ok(_) => {
30                tx.pragma_update(None, "user_version", migration.version)?;
31                tx.commit()?;
32                Ok(())
33            }
34            Err(err) => Err(err),
35        }
36    }
37
38    fn load_migrations_from_dir(dir: Dir) -> SoarResult<Vec<Migration>> {
39        let mut migrations = Vec::new();
40
41        for entry in dir.files() {
42            let path = entry.path();
43
44            if path.extension().and_then(|s| s.to_str()) == Some("sql") {
45                let filename = path
46                    .file_stem()
47                    .and_then(|s| s.to_str())
48                    .ok_or_else(|| SoarError::Custom("Invalid filename".into()))?;
49
50                if !filename.starts_with('V') {
51                    continue;
52                }
53
54                let parts: Vec<&str> = filename[1..].splitn(2, '_').collect();
55                if parts.len() != 2 {
56                    continue;
57                }
58
59                let version = parts[0].parse::<i32>().map_err(|_| {
60                    SoarError::Custom(format!("Invalid version number in filename: {filename}"))
61                })?;
62
63                let sql = entry.contents_utf8().unwrap().to_string();
64
65                migrations.push(Migration { version, sql });
66            }
67        }
68
69        migrations.sort_by_key(|m| m.version);
70
71        let mut expected_version = 1;
72        for migration in &migrations {
73            if migration.version != expected_version {
74                return Err(SoarError::Custom(format!(
75                    "Invalid migration sequence. Expected version {}, found {}",
76                    expected_version, migration.version
77                )));
78            }
79            expected_version += 1;
80        }
81
82        Ok(migrations)
83    }
84
85    pub fn migrate_from_dir(&mut self, dir: Dir) -> SoarResult<()> {
86        let migrations = Self::load_migrations_from_dir(dir)?;
87        let current_version = self.get_current_version()?;
88
89        let pending: Vec<&Migration> = migrations
90            .iter()
91            .filter(|m| m.version > current_version)
92            .collect();
93
94        for migration in pending {
95            self.run_migration(migration)?;
96        }
97
98        Ok(())
99    }
100}
101
102pub fn run_nests(conn: Connection) -> SoarResult<()> {
103    let mut manager = MigrationManager::new(conn)?;
104    manager.migrate_from_dir(NESTS_MIGRATIONS_DIR)?;
105    Ok(())
106}