rustio_core/
migrations.rs1use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9use crate::orm::Db;
10
/// A migration script found on disk, named `<version>_<name>.sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationFile {
    /// Integer prefix of the file name; migrations are applied in ascending order.
    pub version: i64,
    /// Part of the file stem after the first `_` (version prefix and `.sql` excluded).
    pub name: String,
    /// Path of the `.sql` file: the migrations directory joined with the file name.
    pub path: PathBuf,
}
16
/// Knobs for [`apply_with`].
///
/// `ApplyOptions::default()` turns every option off.
#[derive(Clone, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, each pending migration is announced through `log::info!`
    /// just before it is executed.
    pub verbose: bool,
}
21
22pub async fn apply(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<String>> {
23 apply_with(db, dir, ApplyOptions::default()).await
24}
25
26pub async fn apply_with(db: &Db, dir: impl AsRef<Path>, opts: ApplyOptions) -> Result<Vec<String>> {
27 ensure_tracking_table(db).await?;
28
29 let files = discover(dir.as_ref())?;
30 let already = applied_versions(db).await?;
31 let mut newly = Vec::new();
32
33 for file in files {
34 if already.contains(&file.version) {
35 continue;
36 }
37 if opts.verbose {
38 log::info!("applying migration {:04}_{}", file.version, file.name);
39 }
40
41 let sql = fs::read_to_string(&file.path)?;
42 let statements = split_statements(&sql);
43
44 let mut tx = db
45 .pool()
46 .begin()
47 .await
48 .map_err(|e| Error::Internal(format!("begin tx: {e}")))?;
49
50 for stmt in &statements {
51 let trimmed = stmt.trim();
52 if trimmed.is_empty() {
53 continue;
54 }
55 sqlx::query(trimmed)
56 .execute(&mut *tx)
57 .await
58 .map_err(|e| Error::Internal(format!("migration {} failed: {e}", file.name)))?;
59 }
60
61 sqlx::query(
62 "INSERT INTO rustio_migrations (version, name, applied_at)
63 VALUES ($1, $2, NOW())",
64 )
65 .bind(file.version)
66 .bind(&file.name)
67 .execute(&mut *tx)
68 .await
69 .map_err(|e| Error::Internal(format!("tracking insert: {e}")))?;
70
71 tx.commit()
72 .await
73 .map_err(|e| Error::Internal(format!("commit: {e}")))?;
74
75 newly.push(file.name.clone());
76 }
77
78 Ok(newly)
79}
80
81pub async fn applied_versions(db: &Db) -> Result<Vec<i64>> {
82 ensure_tracking_table(db).await?;
83 let rows = sqlx::query_scalar::<_, i64>(
84 "SELECT version FROM rustio_migrations ORDER BY version ASC",
85 )
86 .fetch_all(db.pool())
87 .await?;
88 Ok(rows)
89}
90
91pub async fn status(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<(String, bool)>> {
92 let applied = applied_versions(db).await?;
93 let files = discover(dir.as_ref())?;
94 Ok(files
95 .into_iter()
96 .map(|f| {
97 (
98 format!("{:04}_{}", f.version, f.name),
99 applied.contains(&f.version),
100 )
101 })
102 .collect())
103}
104
105pub fn generate(dir: impl AsRef<Path>, name: &str) -> Result<PathBuf> {
106 let dir = dir.as_ref();
107 fs::create_dir_all(dir)?;
108 let existing = discover(dir).unwrap_or_default();
109 let next = existing.iter().map(|m| m.version).max().unwrap_or(0) + 1;
110 let filename = format!("{:04}_{}.sql", next, slugify(name));
111 let path = dir.join(filename);
112 fs::write(&path, format!("-- {}\n\n", name))?;
113 Ok(path)
114}
115
116fn discover(dir: &Path) -> Result<Vec<MigrationFile>> {
117 if !dir.exists() {
118 return Ok(Vec::new());
119 }
120 let mut out = Vec::new();
121 for entry in fs::read_dir(dir)? {
122 let entry = entry?;
123 let path = entry.path();
124 if path.extension().and_then(|s| s.to_str()) != Some("sql") {
125 continue;
126 }
127 let stem = match path.file_stem().and_then(|s| s.to_str()) {
128 Some(s) => s,
129 None => continue,
130 };
131 let (ver_part, name_part) = match stem.split_once('_') {
132 Some(p) => p,
133 None => continue,
134 };
135 let version: i64 = match ver_part.parse() {
136 Ok(n) => n,
137 Err(_) => continue,
138 };
139 out.push(MigrationFile {
140 version,
141 name: name_part.to_string(),
142 path,
143 });
144 }
145 out.sort_by_key(|m| m.version);
146 Ok(out)
147}
148
149async fn ensure_tracking_table(db: &Db) -> Result<()> {
150 sqlx::query(
151 "CREATE TABLE IF NOT EXISTS rustio_migrations (
152 version BIGINT PRIMARY KEY,
153 name TEXT NOT NULL,
154 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
155 )",
156 )
157 .execute(db.pool())
158 .await?;
159 Ok(())
160}
161
/// Splits a SQL script into individual statements on top-level `;` terminators.
///
/// The terminating `;` is not included in the returned statements. Text after the
/// last `;` is kept only if it is not all whitespace; it may be just a comment.
/// Statement text is otherwise reproduced exactly as written.
///
/// A `;` does not split inside any of these PostgreSQL constructs:
/// * single-quoted literals (`'…'`, with `''` as an escaped quote);
/// * `E'…'` escape strings, where a backslash escapes the next character;
/// * double-quoted identifiers (`"…"`, with `""` as an escaped quote);
/// * dollar-quoted strings (`$$…$$`, `$tag$…$tag$`), e.g. function bodies;
/// * `--` line comments and `/* … */` block comments, which may nest.
fn split_statements(sql: &str) -> Vec<String> {
    /// Lexical context of the character currently being scanned.
    #[derive(Clone, Copy)]
    enum State {
        Normal,
        LineComment,
        BlockComment { depth: usize },
        SingleQuoted { backslash_escapes: bool },
        DoubleQuoted,
        DollarQuoted,
    }

    let chars: Vec<char> = sql.chars().collect();
    let mut out = Vec::new();
    let mut current = String::new();
    let mut state = State::Normal;
    // Full delimiter of the open dollar quote (e.g. `$$` or `$body$`); the closing
    // delimiter must match it exactly.
    let mut dollar_tag: Vec<char> = Vec::new();
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];
        let next = chars.get(i + 1).copied();

        match state {
            State::LineComment => {
                current.push(c);
                i += 1;
                if c == '\n' {
                    state = State::Normal;
                }
            }
            State::BlockComment { depth } => {
                if c == '*' && next == Some('/') {
                    current.push_str("*/");
                    i += 2;
                    state = if depth == 1 {
                        State::Normal
                    } else {
                        State::BlockComment { depth: depth - 1 }
                    };
                } else if c == '/' && next == Some('*') {
                    // PostgreSQL block comments nest (unlike C).
                    current.push_str("/*");
                    i += 2;
                    state = State::BlockComment { depth: depth + 1 };
                } else {
                    current.push(c);
                    i += 1;
                }
            }
            State::SingleQuoted { backslash_escapes } => {
                current.push(c);
                i += 1;
                if backslash_escapes && c == '\\' {
                    if let Some(escaped) = next {
                        current.push(escaped);
                        i += 1;
                    }
                } else if c == '\'' {
                    if next == Some('\'') {
                        current.push('\'');
                        i += 1;
                    } else {
                        state = State::Normal;
                    }
                }
            }
            State::DoubleQuoted => {
                current.push(c);
                i += 1;
                if c == '"' {
                    if next == Some('"') {
                        current.push('"');
                        i += 1;
                    } else {
                        state = State::Normal;
                    }
                }
            }
            State::DollarQuoted => {
                if chars[i..].starts_with(&dollar_tag) {
                    current.extend(dollar_tag.iter());
                    i += dollar_tag.len();
                    dollar_tag.clear();
                    state = State::Normal;
                } else {
                    current.push(c);
                    i += 1;
                }
            }
            State::Normal => match c {
                ';' => {
                    out.push(std::mem::take(&mut current));
                    i += 1;
                }
                '-' if next == Some('-') => {
                    current.push_str("--");
                    i += 2;
                    state = State::LineComment;
                }
                '/' if next == Some('*') => {
                    // Consume both opener chars so `/*/` does not close immediately.
                    current.push_str("/*");
                    i += 2;
                    state = State::BlockComment { depth: 1 };
                }
                '\'' => {
                    state = State::SingleQuoted {
                        backslash_escapes: is_escape_string_prefix(&chars[..i]),
                    };
                    current.push(c);
                    i += 1;
                }
                '"' => {
                    state = State::DoubleQuoted;
                    current.push(c);
                    i += 1;
                }
                '$' => {
                    if let Some(tag) = dollar_quote_tag_at(&chars, i) {
                        current.extend(tag.iter());
                        i += tag.len();
                        dollar_tag = tag;
                        state = State::DollarQuoted;
                    } else {
                        // Positional parameter (`$1`) or part of an identifier.
                        current.push(c);
                        i += 1;
                    }
                }
                other => {
                    current.push(other);
                    i += 1;
                }
            },
        }
    }

    if !current.trim().is_empty() {
        out.push(current);
    }
    out
}

/// Whether `c` may continue an unquoted PostgreSQL identifier.
fn is_sql_ident_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '$'
}

/// Whether the text before a `'` ends in a standalone `E`/`e` prefix. Such a
/// prefix marks an escape string, in which a backslash escapes the next character.
fn is_escape_string_prefix(before: &[char]) -> bool {
    match before {
        [.., prev, 'E' | 'e'] => !is_sql_ident_char(*prev),
        ['E' | 'e'] => true,
        _ => false,
    }
}

/// Returns the dollar-quote delimiter (`$$` or `$tag$`) starting at `chars[start]`.
/// Returns `None` when the `$` there does not open a dollar quote.
fn dollar_quote_tag_at(chars: &[char], start: usize) -> Option<Vec<char>> {
    // A `$` that continues an identifier (e.g. `foo$bar`) never opens a quote.
    if start > 0 && is_sql_ident_char(chars[start - 1]) {
        return None;
    }
    let mut end = start + 1;
    while let Some(&c) = chars.get(end) {
        if c == '$' {
            return Some(chars[start..=end].to_vec());
        }
        // Tags follow identifier rules: no leading digit (so `$1` stays a parameter).
        let allowed = if end == start + 1 {
            c.is_alphabetic() || c == '_'
        } else {
            c.is_alphanumeric() || c == '_'
        };
        if !allowed {
            return None;
        }
        end += 1;
    }
    None
}
273
/// Turns a free-form migration title into a file-name-safe slug.
///
/// ASCII letters and digits are kept and lowercased. Every other character
/// becomes `_`: spaces, punctuation, and all non-ASCII characters. Each input
/// character maps to exactly one output character and nothing is collapsed, so
/// `"Add Users Table!"` becomes `"add_users_table_"`.
fn slugify(name: &str) -> String {
    // `is_ascii_alphanumeric` matches `to_ascii_lowercase`. With the Unicode-aware
    // `is_alphanumeric`, letters such as `É` passed through without lowercasing.
    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() { c.to_ascii_lowercase() } else { '_' })
        .collect()
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// A `;` inside a quoted literal must not end the statement.
    #[test]
    fn split_ignores_semicolon_in_string() {
        let count = split_statements("INSERT INTO t VALUES ('a;b'); SELECT 1;").len();
        assert_eq!(count, 2);
    }

    /// A `;` inside a `--` comment must not end the statement.
    #[test]
    fn split_ignores_line_comments() {
        let count = split_statements("SELECT 1; -- comment with ;\nSELECT 2;").len();
        assert_eq!(count, 2);
    }

    /// Letters are lowercased; spaces and punctuation become underscores.
    #[test]
    fn slugify_lowercases_and_replaces() {
        let slug = slugify("Add Users Table!");
        assert_eq!(slug, "add_users_table_");
    }
}
302}