Skip to main content

rustio_admin/
migrations.rs

1//! Versioned SQL migrations for PostgreSQL. Transactional; the
2//! `rustio_migrations` tracking table records which versions have been
3//! applied.
4
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9use crate::orm::Db;
10
/// A migration script discovered on disk.
///
/// Files are named `<version>_<name>.sql`; `version` orders the migrations
/// and is the key written to the `rustio_migrations` tracking table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationFile {
    /// Numeric prefix parsed from the file stem (e.g. `1` for `0001_init.sql`).
    pub version: i64,
    /// Remainder of the file stem after the first `_` (e.g. `init`).
    pub name: String,
    /// Full path to the `.sql` file.
    pub path: PathBuf,
}
16
/// Knobs controlling [`apply_with`].
#[derive(Clone, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, each pending migration is logged with `log::info!`
    /// before it runs.
    pub verbose: bool,
}
21
22pub async fn apply(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<String>> {
23    apply_with(db, dir, ApplyOptions::default()).await
24}
25
26pub async fn apply_with(db: &Db, dir: impl AsRef<Path>, opts: ApplyOptions) -> Result<Vec<String>> {
27    ensure_tracking_table(db).await?;
28
29    let files = discover(dir.as_ref())?;
30    let already = applied_versions(db).await?;
31    let mut newly = Vec::new();
32
33    for file in files {
34        if already.contains(&file.version) {
35            continue;
36        }
37        if opts.verbose {
38            log::info!("applying migration {:04}_{}", file.version, file.name);
39        }
40
41        let sql = fs::read_to_string(&file.path)?;
42        let statements = split_statements(&sql);
43
44        let mut tx = db
45            .pool()
46            .begin()
47            .await
48            .map_err(|e| Error::Internal(format!("begin tx: {e}")))?;
49
50        for stmt in &statements {
51            let trimmed = stmt.trim();
52            if trimmed.is_empty() {
53                continue;
54            }
55            sqlx::query(trimmed)
56                .execute(&mut *tx)
57                .await
58                .map_err(|e| Error::Internal(format!("migration {} failed: {e}", file.name)))?;
59        }
60
61        sqlx::query(
62            "INSERT INTO rustio_migrations (version, name, applied_at)
63             VALUES ($1, $2, NOW())",
64        )
65        .bind(file.version)
66        .bind(&file.name)
67        .execute(&mut *tx)
68        .await
69        .map_err(|e| Error::Internal(format!("tracking insert: {e}")))?;
70
71        tx.commit()
72            .await
73            .map_err(|e| Error::Internal(format!("commit: {e}")))?;
74
75        newly.push(file.name.clone());
76    }
77
78    Ok(newly)
79}
80
81pub async fn applied_versions(db: &Db) -> Result<Vec<i64>> {
82    ensure_tracking_table(db).await?;
83    let rows =
84        sqlx::query_scalar::<_, i64>("SELECT version FROM rustio_migrations ORDER BY version ASC")
85            .fetch_all(db.pool())
86            .await?;
87    Ok(rows)
88}
89
90pub async fn status(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<(String, bool)>> {
91    let applied = applied_versions(db).await?;
92    let files = discover(dir.as_ref())?;
93    Ok(files
94        .into_iter()
95        .map(|f| {
96            (
97                format!("{:04}_{}", f.version, f.name),
98                applied.contains(&f.version),
99            )
100        })
101        .collect())
102}
103
104pub fn generate(dir: impl AsRef<Path>, name: &str) -> Result<PathBuf> {
105    let dir = dir.as_ref();
106    fs::create_dir_all(dir)?;
107    let existing = discover(dir).unwrap_or_default();
108    let next = existing.iter().map(|m| m.version).max().unwrap_or(0) + 1;
109    let filename = format!("{:04}_{}.sql", next, slugify(name));
110    let path = dir.join(filename);
111    fs::write(&path, format!("-- {}\n\n", name))?;
112    Ok(path)
113}
114
115fn discover(dir: &Path) -> Result<Vec<MigrationFile>> {
116    if !dir.exists() {
117        return Ok(Vec::new());
118    }
119    let mut out = Vec::new();
120    for entry in fs::read_dir(dir)? {
121        let entry = entry?;
122        let path = entry.path();
123        if path.extension().and_then(|s| s.to_str()) != Some("sql") {
124            continue;
125        }
126        let stem = match path.file_stem().and_then(|s| s.to_str()) {
127            Some(s) => s,
128            None => continue,
129        };
130        let (ver_part, name_part) = match stem.split_once('_') {
131            Some(p) => p,
132            None => continue,
133        };
134        let version: i64 = match ver_part.parse() {
135            Ok(n) => n,
136            Err(_) => continue,
137        };
138        out.push(MigrationFile {
139            version,
140            name: name_part.to_string(),
141            path,
142        });
143    }
144    out.sort_by_key(|m| m.version);
145    Ok(out)
146}
147
148async fn ensure_tracking_table(db: &Db) -> Result<()> {
149    sqlx::query(
150        "CREATE TABLE IF NOT EXISTS rustio_migrations (
151            version    BIGINT PRIMARY KEY,
152            name       TEXT NOT NULL,
153            applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
154        )",
155    )
156    .execute(db.pool())
157    .await?;
158    Ok(())
159}
160
/// Split a multi-statement SQL script into individual statements on `;`.
///
/// A `;` ends a statement only at the top level. It is ignored inside:
/// - `'...'` string literals, including `''` escapes and backslash escapes in
///   `E'...'` strings,
/// - `"..."` quoted identifiers,
/// - dollar-quoted bodies (`$$...$$` / `$tag$...$tag$`, as used for PL/pgSQL),
/// - `--` line comments,
/// - `/* ... */` block comments, which may nest.
///
/// The separating `;` is dropped. All other text, including comments, is
/// kept verbatim. Pieces between consecutive `;` are returned even when
/// blank (callers trim and skip them); only a trailing blank remainder is
/// omitted.
fn split_statements(sql: &str) -> Vec<String> {
    /// Lexical context of the scanner.
    enum State {
        /// Top level: the only place where `;` splits.
        Normal,
        /// Inside `'...'`; `escapes` is set for `E'...'` strings.
        Str { escapes: bool },
        /// Inside a `"..."` quoted identifier.
        Ident,
        /// Inside a `--` comment, until the end of the line.
        LineComment,
        /// Inside `/* ... */`; Postgres block comments nest, so track depth.
        BlockComment { depth: usize },
        /// Inside a dollar-quoted body opened by `tag` (e.g. `$$` or `$fn$`).
        Dollar { tag: String },
    }

    /// `true` when `text` ends in a standalone `E`/`e`, i.e. a quote that
    /// follows opens an escape string rather than continuing an identifier.
    fn ends_with_escape_prefix(text: &str) -> bool {
        let mut rev = text.chars().rev();
        matches!(rev.next(), Some('E' | 'e'))
            && !matches!(rev.next(), Some(p) if p.is_alphanumeric() || p == '_' || p == '$')
    }

    /// Given the characters after a `$`, return the remainder of an opening
    /// dollar-quote tag (`"$"` for `$$`, `"fn$"` for `$fn$`). Returns `None`
    /// when the `$` does not open one, e.g. the positional parameter `$1`
    /// (tags may not start with a digit).
    fn dollar_tag_rest(ahead: impl Iterator<Item = char>) -> Option<String> {
        let mut rest = String::new();
        for ch in ahead {
            if ch == '$' {
                rest.push('$');
                return Some(rest);
            }
            let allowed = ch == '_'
                || if rest.is_empty() {
                    ch.is_alphabetic()
                } else {
                    ch.is_alphanumeric()
                };
            if !allowed {
                return None;
            }
            rest.push(ch);
        }
        None
    }

    let mut out = Vec::new();
    let mut current = String::new();
    let mut state = State::Normal;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        let next_state = match &mut state {
            State::Normal => match c {
                ';' => {
                    out.push(std::mem::take(&mut current));
                    None
                }
                '\'' => {
                    // Inspect the prefix before the quote is appended.
                    let escapes = ends_with_escape_prefix(&current);
                    current.push(c);
                    Some(State::Str { escapes })
                }
                '"' => {
                    current.push(c);
                    Some(State::Ident)
                }
                '-' if chars.peek() == Some(&'-') => {
                    chars.next();
                    current.push_str("--");
                    Some(State::LineComment)
                }
                '/' if chars.peek() == Some(&'*') => {
                    // Consume the `*` too, so `/*/` does not close immediately.
                    chars.next();
                    current.push_str("/*");
                    Some(State::BlockComment { depth: 1 })
                }
                '$' => {
                    current.push('$');
                    match dollar_tag_rest(chars.clone()) {
                        Some(rest) => {
                            for _ in 0..rest.chars().count() {
                                chars.next();
                            }
                            current.push_str(&rest);
                            Some(State::Dollar { tag: format!("${}", rest) })
                        }
                        None => None,
                    }
                }
                other => {
                    current.push(other);
                    None
                }
            },
            State::Str { escapes } => {
                current.push(c);
                if *escapes && c == '\\' {
                    // Backslash escape: the next character (even a quote) is literal.
                    if let Some(escaped) = chars.next() {
                        current.push(escaped);
                    }
                    None
                } else if c == '\'' {
                    if chars.peek() == Some(&'\'') {
                        // `''` is an embedded quote, not the end of the literal.
                        chars.next();
                        current.push('\'');
                        None
                    } else {
                        Some(State::Normal)
                    }
                } else {
                    None
                }
            }
            State::Ident => {
                current.push(c);
                if c == '"' {
                    if chars.peek() == Some(&'"') {
                        // `""` is an embedded double quote inside the identifier.
                        chars.next();
                        current.push('"');
                        None
                    } else {
                        Some(State::Normal)
                    }
                } else {
                    None
                }
            }
            State::LineComment => {
                current.push(c);
                if c == '\n' {
                    Some(State::Normal)
                } else {
                    None
                }
            }
            State::BlockComment { depth } => {
                current.push(c);
                if c == '*' && chars.peek() == Some(&'/') {
                    chars.next();
                    current.push('/');
                    *depth -= 1;
                    if *depth == 0 {
                        Some(State::Normal)
                    } else {
                        None
                    }
                } else if c == '/' && chars.peek() == Some(&'*') {
                    chars.next();
                    current.push('*');
                    *depth += 1;
                    None
                } else {
                    None
                }
            }
            State::Dollar { tag } => {
                current.push(c);
                // `c` may be the first `$` of the closing tag; the rest must follow.
                let rest = &tag[1..];
                let rest_len = rest.chars().count();
                if c == '$' && chars.clone().take(rest_len).eq(rest.chars()) {
                    for _ in 0..rest_len {
                        chars.next();
                    }
                    current.push_str(rest);
                    Some(State::Normal)
                } else {
                    None
                }
            }
        };
        if let Some(next) = next_state {
            state = next;
        }
    }

    if !current.trim().is_empty() {
        out.push(current);
    }
    out
}
268
/// Turn a free-form migration title into a file-name fragment.
///
/// Letters and digits are lowercased; every other character becomes `_`.
/// Lowercasing is Unicode-aware (`'Ä'` → `'ä'`) to match the Unicode-aware
/// `is_alphanumeric` test; ASCII-only lowercasing would leave non-ASCII
/// capitals untouched.
fn slugify(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    for c in name.chars() {
        if c.is_alphanumeric() {
            slug.extend(c.to_lowercase());
        } else {
            slug.push('_');
        }
    }
    slug
}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_ignores_semicolon_in_string() {
        let statements = split_statements("INSERT INTO t VALUES ('a;b'); SELECT 1;");
        assert_eq!(statements.len(), 2);
    }

    #[test]
    fn split_ignores_line_comments() {
        let statements = split_statements("SELECT 1; -- comment with ;\nSELECT 2;");
        assert_eq!(statements.len(), 2);
    }

    #[test]
    fn slugify_lowercases_and_replaces() {
        let slug = slugify("Add Users Table!");
        assert_eq!(slug, "add_users_table_");
    }
}