Skip to main content

rustio_core/
migrations.rs

1//! Versioned SQL migrations for PostgreSQL. Transactional where
2//! possible; the tracking table records which versions have been
3//! applied.
4
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9use crate::orm::Db;
10
/// One migration script found on disk.
///
/// File names have the form `<version>_<name>.sql`; the three fields hold
/// the pieces of that name plus the full path to the file.
#[derive(Debug, Clone)]
pub struct MigrationFile {
    /// Integer parsed from the part of the file stem before the first `_`.
    pub version: i64,
    /// The part of the file stem after the first `_`.
    pub name: String,
    /// Location of the `.sql` file.
    pub path: PathBuf,
}
16
/// Knobs for [`apply_with`]. `ApplyOptions::default()` leaves every
/// option switched off.
#[derive(Debug, Clone, Default)]
pub struct ApplyOptions {
    /// When `true`, an info-level log line is emitted before each
    /// migration is applied.
    pub verbose: bool,
}
21
22pub async fn apply(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<String>> {
23    apply_with(db, dir, ApplyOptions::default()).await
24}
25
26pub async fn apply_with(db: &Db, dir: impl AsRef<Path>, opts: ApplyOptions) -> Result<Vec<String>> {
27    ensure_tracking_table(db).await?;
28
29    let files = discover(dir.as_ref())?;
30    let already = applied_versions(db).await?;
31    let mut newly = Vec::new();
32
33    for file in files {
34        if already.contains(&file.version) {
35            continue;
36        }
37        if opts.verbose {
38            log::info!("applying migration {:04}_{}", file.version, file.name);
39        }
40
41        let sql = fs::read_to_string(&file.path)?;
42        let statements = split_statements(&sql);
43
44        let mut tx = db
45            .pool()
46            .begin()
47            .await
48            .map_err(|e| Error::Internal(format!("begin tx: {e}")))?;
49
50        for stmt in &statements {
51            let trimmed = stmt.trim();
52            if trimmed.is_empty() {
53                continue;
54            }
55            sqlx::query(trimmed)
56                .execute(&mut *tx)
57                .await
58                .map_err(|e| Error::Internal(format!("migration {} failed: {e}", file.name)))?;
59        }
60
61        sqlx::query(
62            "INSERT INTO rustio_migrations (version, name, applied_at)
63             VALUES ($1, $2, NOW())",
64        )
65        .bind(file.version)
66        .bind(&file.name)
67        .execute(&mut *tx)
68        .await
69        .map_err(|e| Error::Internal(format!("tracking insert: {e}")))?;
70
71        tx.commit()
72            .await
73            .map_err(|e| Error::Internal(format!("commit: {e}")))?;
74
75        newly.push(file.name.clone());
76    }
77
78    Ok(newly)
79}
80
81pub async fn applied_versions(db: &Db) -> Result<Vec<i64>> {
82    ensure_tracking_table(db).await?;
83    let rows = sqlx::query_scalar::<_, i64>(
84        "SELECT version FROM rustio_migrations ORDER BY version ASC",
85    )
86    .fetch_all(db.pool())
87    .await?;
88    Ok(rows)
89}
90
91pub async fn status(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<(String, bool)>> {
92    let applied = applied_versions(db).await?;
93    let files = discover(dir.as_ref())?;
94    Ok(files
95        .into_iter()
96        .map(|f| {
97            (
98                format!("{:04}_{}", f.version, f.name),
99                applied.contains(&f.version),
100            )
101        })
102        .collect())
103}
104
105pub fn generate(dir: impl AsRef<Path>, name: &str) -> Result<PathBuf> {
106    let dir = dir.as_ref();
107    fs::create_dir_all(dir)?;
108    let existing = discover(dir).unwrap_or_default();
109    let next = existing.iter().map(|m| m.version).max().unwrap_or(0) + 1;
110    let filename = format!("{:04}_{}.sql", next, slugify(name));
111    let path = dir.join(filename);
112    fs::write(&path, format!("-- {}\n\n", name))?;
113    Ok(path)
114}
115
116fn discover(dir: &Path) -> Result<Vec<MigrationFile>> {
117    if !dir.exists() {
118        return Ok(Vec::new());
119    }
120    let mut out = Vec::new();
121    for entry in fs::read_dir(dir)? {
122        let entry = entry?;
123        let path = entry.path();
124        if path.extension().and_then(|s| s.to_str()) != Some("sql") {
125            continue;
126        }
127        let stem = match path.file_stem().and_then(|s| s.to_str()) {
128            Some(s) => s,
129            None => continue,
130        };
131        let (ver_part, name_part) = match stem.split_once('_') {
132            Some(p) => p,
133            None => continue,
134        };
135        let version: i64 = match ver_part.parse() {
136            Ok(n) => n,
137            Err(_) => continue,
138        };
139        out.push(MigrationFile {
140            version,
141            name: name_part.to_string(),
142            path,
143        });
144    }
145    out.sort_by_key(|m| m.version);
146    Ok(out)
147}
148
149async fn ensure_tracking_table(db: &Db) -> Result<()> {
150    sqlx::query(
151        "CREATE TABLE IF NOT EXISTS rustio_migrations (
152            version    BIGINT PRIMARY KEY,
153            name       TEXT NOT NULL,
154            applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
155        )",
156    )
157    .execute(db.pool())
158    .await?;
159    Ok(())
160}
161
/// Split a multi-statement SQL file on `;`, but not on `;` inside
/// single-quoted strings, double-quoted identifiers, dollar-quoted
/// bodies (Postgres PL/pgSQL), or `--` / `/* */` comments.
///
/// Every returned statement is an exact substring of `sql` with the
/// terminating `;` removed; comments and whitespace are kept. A fragment
/// between two `;` is returned even when it is empty, while trailing text
/// after the last `;` is only returned when it is not all whitespace.
///
/// Unterminated quotes, dollar bodies, and block comments extend to the
/// end of the input. Backslash escapes (`E'...\'...'`) are not
/// interpreted.
fn split_statements(sql: &str) -> Vec<String> {
    // All delimiters are ASCII, and ASCII bytes never occur inside a
    // multi-byte UTF-8 sequence, so scanning bytes and slicing `sql` at
    // delimiter offsets is always on a char boundary.
    let bytes = sql.as_bytes();
    let mut out = Vec::new();
    let mut start = 0; // byte offset where the current statement begins
    let mut i = 0;

    while i < bytes.len() {
        i = match bytes[i] {
            b'\'' | b'"' => skip_quoted(bytes, i),
            b'-' if bytes.get(i + 1) == Some(&b'-') => skip_line_comment(bytes, i),
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i),
            b'$' => skip_dollar_quoted(sql, i),
            b';' => {
                out.push(sql[start..i].to_string());
                start = i + 1;
                i + 1
            }
            _ => i + 1,
        };
    }

    let tail = &sql[start..];
    if !tail.trim().is_empty() {
        out.push(tail.to_string());
    }
    out
}

/// Return the offset just past the quoted run opening at `open`; a
/// doubled quote character inside the run is an escaped quote.
fn skip_quoted(bytes: &[u8], open: usize) -> usize {
    let quote = bytes[open];
    let mut i = open + 1;
    while i < bytes.len() {
        if bytes[i] == quote {
            if bytes.get(i + 1) == Some(&quote) {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    bytes.len()
}

/// Return the offset just past the newline ending the `--` comment that
/// starts at `open`.
fn skip_line_comment(bytes: &[u8], open: usize) -> usize {
    bytes[open..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |newline| open + newline + 1)
}

/// Return the offset just past the `/* ... */` comment starting at
/// `open`, honouring nested block comments.
fn skip_block_comment(bytes: &[u8], open: usize) -> usize {
    // Start after the opening `/*` so that `/*/` is not mistaken for a
    // complete comment.
    let mut i = open + 2;
    let mut depth = 1usize;
    while i < bytes.len() {
        let rest = &bytes[i..];
        if rest.starts_with(b"*/") {
            i += 2;
            depth -= 1;
            if depth == 0 {
                return i;
            }
        } else if rest.starts_with(b"/*") {
            i += 2;
            depth += 1;
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Handle the `$` at `open`: when it opens a dollar-quoted body
/// (`$$...$$` or `$tag$...$tag$`), return the offset just past the
/// closing delimiter; otherwise return `open + 1`.
fn skip_dollar_quoted(sql: &str, open: usize) -> usize {
    // A `$` glued to a preceding identifier character belongs to that
    // identifier (`foo$bar`), not to a dollar quote.
    let glued_to_identifier = sql[..open]
        .chars()
        .next_back()
        .map_or(false, |prev| prev.is_alphanumeric() || prev == '_');
    if glued_to_identifier {
        return open + 1;
    }

    let after = &sql[open + 1..];
    let tag_len = after
        .find(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
        .unwrap_or(after.len());
    let tag_is_closed = after[tag_len..].starts_with('$');
    // `$1`, `$2`, ... are positional parameters; a tag cannot start with a digit.
    let starts_with_digit = after.starts_with(|ch: char| ch.is_ascii_digit());
    if !tag_is_closed || starts_with_digit {
        return open + 1;
    }

    // The full delimiter, e.g. `$$` or `$body$`; the body ends at its
    // next occurrence.
    let delimiter = &sql[open..open + tag_len + 2];
    let body_start = open + delimiter.len();
    match sql[body_start..].find(delimiter) {
        Some(offset) => body_start + offset + delimiter.len(),
        None => sql.len(),
    }
}
273
/// Turn `name` into a file-name fragment: alphanumeric characters are
/// kept (ASCII letters lower-cased), every other character becomes `_`.
fn slugify(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    for ch in name.chars() {
        if ch.is_alphanumeric() {
            slug.push(ch.to_ascii_lowercase());
        } else {
            slug.push('_');
        }
    }
    slug
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// A `;` inside a single-quoted literal must not end the statement.
    #[test]
    fn split_ignores_semicolon_in_string() {
        let statements = split_statements("INSERT INTO t VALUES ('a;b'); SELECT 1;");
        assert_eq!(statements.len(), 2);
    }

    /// A `;` inside a `--` comment must not end the statement.
    #[test]
    fn split_ignores_line_comments() {
        let statements = split_statements("SELECT 1; -- comment with ;\nSELECT 2;");
        assert_eq!(statements.len(), 2);
    }

    /// Letters are lower-cased; spaces and punctuation become `_`.
    #[test]
    fn slugify_lowercases_and_replaces() {
        let slug = slugify("Add Users Table!");
        assert_eq!(slug, "add_users_table_");
    }
}