Skip to main content

rustio_admin/
migrations.rs

1//! Versioned SQL migrations for PostgreSQL. Transactional; the
2//! `rustio_migrations` tracking table records which versions have been
3//! applied.
4
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9use crate::orm::Db;
10
/// A migration script found on disk, parsed from a `{version}_{name}.sql` file name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationFile {
    /// Numeric prefix of the file stem (the text before the first `_`).
    pub version: i64,
    /// Remainder of the file stem after the first `_`, without the `.sql` extension.
    pub name: String,
    /// Path to the `.sql` file (the migrations directory joined with the file name).
    pub path: PathBuf,
}
17
/// Tuning knobs for [`apply_with`]; the default is quiet.
#[derive(Clone, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, emit a `log::info!` line naming each migration just before it runs.
    pub verbose: bool,
}
23
24// public:
25pub async fn apply(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<String>> {
26    apply_with(db, dir, ApplyOptions::default()).await
27}
28
29// public:
30pub async fn apply_with(db: &Db, dir: impl AsRef<Path>, opts: ApplyOptions) -> Result<Vec<String>> {
31    ensure_tracking_table(db).await?;
32
33    let files = discover(dir.as_ref())?;
34    let already = applied_versions(db).await?;
35    let mut newly = Vec::new();
36
37    for file in files {
38        if already.contains(&file.version) {
39            continue;
40        }
41        if opts.verbose {
42            log::info!("applying migration {:04}_{}", file.version, file.name);
43        }
44
45        let sql = fs::read_to_string(&file.path)?;
46        let statements = split_statements(&sql);
47
48        let mut tx = db
49            .pool()
50            .begin()
51            .await
52            .map_err(|e| Error::Internal(format!("begin tx: {e}")))?;
53
54        for stmt in &statements {
55            let trimmed = stmt.trim();
56            if trimmed.is_empty() {
57                continue;
58            }
59            sqlx::query(trimmed)
60                .execute(&mut *tx)
61                .await
62                .map_err(|e| Error::Internal(format!("migration {} failed: {e}", file.name)))?;
63        }
64
65        sqlx::query(
66            "INSERT INTO rustio_migrations (version, name, applied_at)
67             VALUES ($1, $2, NOW())",
68        )
69        .bind(file.version)
70        .bind(&file.name)
71        .execute(&mut *tx)
72        .await
73        .map_err(|e| Error::Internal(format!("tracking insert: {e}")))?;
74
75        tx.commit()
76            .await
77            .map_err(|e| Error::Internal(format!("commit: {e}")))?;
78
79        newly.push(file.name.clone());
80    }
81
82    Ok(newly)
83}
84
85// public:
86pub async fn applied_versions(db: &Db) -> Result<Vec<i64>> {
87    ensure_tracking_table(db).await?;
88    let rows =
89        sqlx::query_scalar::<_, i64>("SELECT version FROM rustio_migrations ORDER BY version ASC")
90            .fetch_all(db.pool())
91            .await?;
92    Ok(rows)
93}
94
95// public:
96pub async fn status(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<(String, bool)>> {
97    let applied = applied_versions(db).await?;
98    let files = discover(dir.as_ref())?;
99    Ok(files
100        .into_iter()
101        .map(|f| {
102            (
103                format!("{:04}_{}", f.version, f.name),
104                applied.contains(&f.version),
105            )
106        })
107        .collect())
108}
109
110// public:
111pub fn generate(dir: impl AsRef<Path>, name: &str) -> Result<PathBuf> {
112    let dir = dir.as_ref();
113    fs::create_dir_all(dir)?;
114    let existing = discover(dir).unwrap_or_default();
115    let next = existing.iter().map(|m| m.version).max().unwrap_or(0) + 1;
116    let filename = format!("{:04}_{}.sql", next, slugify(name));
117    let path = dir.join(filename);
118    fs::write(&path, format!("-- {name}\n\n"))?;
119    Ok(path)
120}
121
122fn discover(dir: &Path) -> Result<Vec<MigrationFile>> {
123    if !dir.exists() {
124        return Ok(Vec::new());
125    }
126    let mut out = Vec::new();
127    for entry in fs::read_dir(dir)? {
128        let entry = entry?;
129        let path = entry.path();
130        if path.extension().and_then(|s| s.to_str()) != Some("sql") {
131            continue;
132        }
133        let stem = match path.file_stem().and_then(|s| s.to_str()) {
134            Some(s) => s,
135            None => continue,
136        };
137        let (ver_part, name_part) = match stem.split_once('_') {
138            Some(p) => p,
139            None => continue,
140        };
141        let version: i64 = match ver_part.parse() {
142            Ok(n) => n,
143            Err(_) => continue,
144        };
145        out.push(MigrationFile {
146            version,
147            name: name_part.to_string(),
148            path,
149        });
150    }
151    out.sort_by_key(|m| m.version);
152    Ok(out)
153}
154
155async fn ensure_tracking_table(db: &Db) -> Result<()> {
156    sqlx::query(
157        "CREATE TABLE IF NOT EXISTS rustio_migrations (
158            version    BIGINT PRIMARY KEY,
159            name       TEXT NOT NULL,
160            applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
161        )",
162    )
163    .execute(db.pool())
164    .await?;
165    Ok(())
166}
167
/// Split a multi-statement PostgreSQL script into individual statements.
///
/// The script is cut at every top-level `;`, and the `;` itself is dropped. A `;` does not
/// split when it is inside any of these:
/// - a single-quoted string, including `E'...'` strings with backslash escapes;
/// - a double-quoted identifier;
/// - a dollar-quoted body (`$$ ... $$` or `$tag$ ... $tag$`, as used by PL/pgSQL);
/// - a `--` line comment or a (possibly nested) `/* */` block comment.
///
/// All other text, comments included, is copied verbatim. Chunks that contain only whitespace
/// and comments are omitted, so every returned statement has SQL to execute. An unterminated
/// quote or comment extends to the end of the input.
///
/// NOTE(review): plain `'...'` strings are treated as having no backslash escapes. This
/// matches PostgreSQL's default `standard_conforming_strings = on`.
fn split_statements(sql: &str) -> Vec<String> {
    let chars: Vec<char> = sql.chars().collect();
    let mut out = Vec::new();
    let mut current = String::new();
    // Whether `current` holds anything besides whitespace and comments.
    let mut has_code = false;
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];
        if c == ';' {
            if has_code {
                out.push(std::mem::take(&mut current));
            } else {
                current.clear();
            }
            has_code = false;
            i += 1;
            continue;
        }

        let next = chars.get(i + 1).copied();
        // Each arm yields the (exclusive) end of the token starting at `i` and whether that
        // token is executable SQL rather than a comment or whitespace.
        let (end, is_code) = match c {
            '-' if next == Some('-') => (line_comment_end(&chars, i), false),
            '/' if next == Some('*') => (block_comment_end(&chars, i), false),
            '\'' => (quoted_end(&chars, i, starts_escape_string(&chars, i)), true),
            '"' => (quoted_end(&chars, i, false), true),
            '$' => (dollar_quote_end(&chars, i).unwrap_or(i + 1), true),
            _ => (i + 1, !c.is_whitespace()),
        };
        current.extend(&chars[i..end]);
        has_code |= is_code;
        i = end;
    }

    if has_code {
        out.push(current);
    }
    out
}

/// Characters PostgreSQL allows after the first character of an unquoted identifier.
fn is_ident_char(c: char) -> bool {
    c == '_' || c == '$' || c.is_ascii_alphanumeric() || !c.is_ascii()
}

/// Index just past the `--` comment starting at `start` (its newline included).
fn line_comment_end(chars: &[char], start: usize) -> usize {
    chars[start..]
        .iter()
        .position(|&c| c == '\n')
        .map_or(chars.len(), |p| start + p + 1)
}

/// Index just past the `/* */` comment starting at `start`, honouring nesting.
fn block_comment_end(chars: &[char], start: usize) -> usize {
    let mut depth = 0usize;
    let mut i = start;
    while i < chars.len() {
        match (chars[i], chars.get(i + 1).copied()) {
            ('/', Some('*')) => {
                depth += 1;
                i += 2;
            }
            ('*', Some('/')) => {
                // `depth >= 1` here: the scan starts on the opening `/*`.
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    chars.len()
}

/// Whether the `'` at `quote` opens an escape string, i.e. is preceded by a standalone
/// `E`/`e` rather than by the tail of a longer identifier.
fn starts_escape_string(chars: &[char], quote: usize) -> bool {
    match quote.checked_sub(1).map(|p| chars[p]) {
        Some('E' | 'e') => quote < 2 || !is_ident_char(chars[quote - 2]),
        _ => false,
    }
}

/// Index just past the quoted token opening at `start` (on a `'` or `"`).
///
/// A doubled quote character is an escaped quote. With `backslash_escapes`, a `\` also
/// escapes the following character.
fn quoted_end(chars: &[char], start: usize, backslash_escapes: bool) -> usize {
    let quote = chars[start];
    let mut i = start + 1;
    while i < chars.len() {
        if backslash_escapes && chars[i] == '\\' {
            i += 2;
        } else if chars[i] != quote {
            i += 1;
        } else if chars.get(i + 1) == Some(&quote) {
            i += 2;
        } else {
            return i + 1;
        }
    }
    chars.len()
}

/// If a dollar-quote delimiter (`$$` or `$tag$`) opens at `start`, return the index just past
/// the matching closing delimiter, or the end of input if it is never closed.
///
/// Returns `None` when the `$` is part of an identifier (`a$b`), a positional parameter
/// (`$1`), or otherwise not followed by a valid tag and a second `$`.
fn dollar_quote_end(chars: &[char], start: usize) -> Option<usize> {
    if start > 0 && is_ident_char(chars[start - 1]) {
        return None;
    }
    // Tags follow identifier rules: no leading digit and no `$`.
    let tag_len = chars[start + 1..]
        .iter()
        .enumerate()
        .take_while(|&(k, &c)| {
            c == '_' || c.is_ascii_alphabetic() || !c.is_ascii() || (k > 0 && c.is_ascii_digit())
        })
        .count();
    let close = start + 1 + tag_len;
    if chars.get(close) != Some(&'$') {
        return None;
    }
    let delim = &chars[start..=close];
    let body = close + 1;
    let end = chars[body..]
        .windows(delim.len())
        .position(|w| w == delim)
        .map_or(chars.len(), |p| body + p + delim.len());
    Some(end)
}
275
/// Turn a free-form migration name into a file-name fragment.
///
/// Alphanumeric characters are kept (ASCII letters lowercased); every other character,
/// spaces and punctuation included, becomes `_`.
fn slugify(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    for ch in name.chars() {
        let mapped = if ch.is_alphanumeric() {
            ch.to_ascii_lowercase()
        } else {
            '_'
        };
        slug.push(mapped);
    }
    slug
}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_ignores_semicolon_in_string() {
        let statements = split_statements("INSERT INTO t VALUES ('a;b'); SELECT 1;");
        assert_eq!(statements.len(), 2);
    }

    #[test]
    fn split_ignores_line_comments() {
        let statements = split_statements("SELECT 1; -- comment with ;\nSELECT 2;");
        assert_eq!(statements.len(), 2);
    }

    #[test]
    fn slugify_lowercases_and_replaces() {
        let slug = slugify("Add Users Table!");
        assert_eq!(slug, "add_users_table_");
    }
}