rustio_admin/
migrations.rs1use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9use crate::orm::Db;
10
/// A migration script found on disk, named `<version>_<name>.sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationFile {
    /// Numeric prefix of the file name; pending migrations are applied in ascending order
    /// of this value.
    pub version: i64,
    /// Part of the file stem after the first `_` (e.g. `add_users` for `0003_add_users.sql`).
    pub name: String,
    /// Path to the `.sql` file.
    pub path: PathBuf,
}
17
/// Knobs for [`apply_with`]; `ApplyOptions::default()` is what [`apply`] uses.
#[derive(Default, Debug, Clone)]
pub struct ApplyOptions {
    /// When `true`, an `info` log line is emitted before each pending migration runs.
    pub verbose: bool,
}
23
24pub async fn apply(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<String>> {
26 apply_with(db, dir, ApplyOptions::default()).await
27}
28
29pub async fn apply_with(db: &Db, dir: impl AsRef<Path>, opts: ApplyOptions) -> Result<Vec<String>> {
31 ensure_tracking_table(db).await?;
32
33 let files = discover(dir.as_ref())?;
34 let already = applied_versions(db).await?;
35 let mut newly = Vec::new();
36
37 for file in files {
38 if already.contains(&file.version) {
39 continue;
40 }
41 if opts.verbose {
42 log::info!("applying migration {:04}_{}", file.version, file.name);
43 }
44
45 let sql = fs::read_to_string(&file.path)?;
46 let statements = split_statements(&sql);
47
48 let mut tx = db
49 .pool()
50 .begin()
51 .await
52 .map_err(|e| Error::Internal(format!("begin tx: {e}")))?;
53
54 for stmt in &statements {
55 let trimmed = stmt.trim();
56 if trimmed.is_empty() {
57 continue;
58 }
59 sqlx::query(trimmed)
60 .execute(&mut *tx)
61 .await
62 .map_err(|e| Error::Internal(format!("migration {} failed: {e}", file.name)))?;
63 }
64
65 sqlx::query(
66 "INSERT INTO rustio_migrations (version, name, applied_at)
67 VALUES ($1, $2, NOW())",
68 )
69 .bind(file.version)
70 .bind(&file.name)
71 .execute(&mut *tx)
72 .await
73 .map_err(|e| Error::Internal(format!("tracking insert: {e}")))?;
74
75 tx.commit()
76 .await
77 .map_err(|e| Error::Internal(format!("commit: {e}")))?;
78
79 newly.push(file.name.clone());
80 }
81
82 Ok(newly)
83}
84
85pub async fn applied_versions(db: &Db) -> Result<Vec<i64>> {
87 ensure_tracking_table(db).await?;
88 let rows =
89 sqlx::query_scalar::<_, i64>("SELECT version FROM rustio_migrations ORDER BY version ASC")
90 .fetch_all(db.pool())
91 .await?;
92 Ok(rows)
93}
94
95pub async fn status(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<(String, bool)>> {
97 let applied = applied_versions(db).await?;
98 let files = discover(dir.as_ref())?;
99 Ok(files
100 .into_iter()
101 .map(|f| {
102 (
103 format!("{:04}_{}", f.version, f.name),
104 applied.contains(&f.version),
105 )
106 })
107 .collect())
108}
109
110pub fn generate(dir: impl AsRef<Path>, name: &str) -> Result<PathBuf> {
112 let dir = dir.as_ref();
113 fs::create_dir_all(dir)?;
114 let existing = discover(dir).unwrap_or_default();
115 let next = existing.iter().map(|m| m.version).max().unwrap_or(0) + 1;
116 let filename = format!("{:04}_{}.sql", next, slugify(name));
117 let path = dir.join(filename);
118 fs::write(&path, format!("-- {name}\n\n"))?;
119 Ok(path)
120}
121
122fn discover(dir: &Path) -> Result<Vec<MigrationFile>> {
123 if !dir.exists() {
124 return Ok(Vec::new());
125 }
126 let mut out = Vec::new();
127 for entry in fs::read_dir(dir)? {
128 let entry = entry?;
129 let path = entry.path();
130 if path.extension().and_then(|s| s.to_str()) != Some("sql") {
131 continue;
132 }
133 let stem = match path.file_stem().and_then(|s| s.to_str()) {
134 Some(s) => s,
135 None => continue,
136 };
137 let (ver_part, name_part) = match stem.split_once('_') {
138 Some(p) => p,
139 None => continue,
140 };
141 let version: i64 = match ver_part.parse() {
142 Ok(n) => n,
143 Err(_) => continue,
144 };
145 out.push(MigrationFile {
146 version,
147 name: name_part.to_string(),
148 path,
149 });
150 }
151 out.sort_by_key(|m| m.version);
152 Ok(out)
153}
154
155async fn ensure_tracking_table(db: &Db) -> Result<()> {
156 sqlx::query(
157 "CREATE TABLE IF NOT EXISTS rustio_migrations (
158 version BIGINT PRIMARY KEY,
159 name TEXT NOT NULL,
160 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
161 )",
162 )
163 .execute(db.pool())
164 .await?;
165 Ok(())
166}
167
/// Splits a SQL script into individual statements on top-level `;` terminators.
///
/// The scanner follows enough of PostgreSQL's lexical rules that `;` only terminates a
/// statement outside of:
///
/// * single-quoted literals (`''` is an escaped quote; inside `E'...'` a backslash also
///   escapes the next character),
/// * double-quoted identifiers (`""` is an escaped quote),
/// * `--` line comments,
/// * `/* ... */` block comments, which may nest,
/// * dollar-quoted strings (`$$ ... $$` or `$tag$ ... $tag$`).
///
/// Comments and quoted text are kept verbatim; the terminating `;` is dropped. Statements are
/// not trimmed, and the text between two consecutive `;` is returned even when blank. Text
/// after the last `;` is returned only if it contains non-whitespace.
fn split_statements(sql: &str) -> Vec<String> {
    // Characters that can continue an identifier. A `$` right after one of them is part of
    // that identifier rather than the start of a dollar quote.
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == '$'
    }

    // True when the text just before a `'` is a standalone `E`/`e` escape-string prefix.
    fn is_escape_string_prefix(before: &str) -> bool {
        let mut rev = before.chars().rev();
        matches!(rev.next(), Some('E' | 'e'))
            && !matches!(rev.next(), Some(p) if is_ident_char(p))
    }

    // Length of the dollar-quote delimiter (`$$` or `$tag$`) starting at `rest[0]`, if any.
    // Tags follow identifier rules, so a bind parameter such as `$1` is not a delimiter.
    fn dollar_tag_len(rest: &[char]) -> Option<usize> {
        for (j, &ch) in rest.iter().enumerate().skip(1) {
            if ch == '$' {
                return Some(j + 1);
            }
            let valid = if j == 1 {
                ch.is_alphabetic() || ch == '_'
            } else {
                ch.is_alphanumeric() || ch == '_'
            };
            if !valid {
                return None;
            }
        }
        None
    }

    let chars: Vec<char> = sql.chars().collect();
    let mut out = Vec::new();
    let mut current = String::new();
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];
        let next = chars.get(i + 1).copied();
        match c {
            '\'' | '"' => {
                let backslash_escapes = c == '\'' && is_escape_string_prefix(&current);
                current.push(c);
                i += 1;
                while let Some(&ch) = chars.get(i) {
                    current.push(ch);
                    i += 1;
                    if backslash_escapes && ch == '\\' {
                        if let Some(&escaped) = chars.get(i) {
                            current.push(escaped);
                            i += 1;
                        }
                    } else if ch == c {
                        // A doubled quote is an escaped quote, not the end of the token.
                        if chars.get(i) == Some(&c) {
                            current.push(c);
                            i += 1;
                        } else {
                            break;
                        }
                    }
                }
            }
            '-' if next == Some('-') => {
                while let Some(&ch) = chars.get(i) {
                    current.push(ch);
                    i += 1;
                    if ch == '\n' {
                        break;
                    }
                }
            }
            '/' if next == Some('*') => {
                // Consume the whole opener so that `/*/` does not close immediately.
                current.push_str("/*");
                i += 2;
                let mut depth = 1usize;
                while depth > 0 && i < chars.len() {
                    match (chars[i], chars.get(i + 1).copied()) {
                        ('/', Some('*')) => {
                            current.push_str("/*");
                            i += 2;
                            depth += 1;
                        }
                        ('*', Some('/')) => {
                            current.push_str("*/");
                            i += 2;
                            depth -= 1;
                        }
                        (ch, _) => {
                            current.push(ch);
                            i += 1;
                        }
                    }
                }
            }
            '$' => {
                let follows_identifier =
                    matches!(current.chars().next_back(), Some(p) if is_ident_char(p));
                let delimiter = if follows_identifier {
                    None
                } else {
                    dollar_tag_len(&chars[i..])
                };
                match delimiter {
                    Some(len) => {
                        let tag = &chars[i..i + len];
                        current.extend(tag);
                        i += len;
                        // Copy the body verbatim up to and including the matching closer.
                        while i < chars.len() {
                            if chars[i..].starts_with(tag) {
                                current.extend(tag);
                                i += len;
                                break;
                            }
                            current.push(chars[i]);
                            i += 1;
                        }
                    }
                    None => {
                        current.push(c);
                        i += 1;
                    }
                }
            }
            ';' => {
                out.push(std::mem::take(&mut current));
                i += 1;
            }
            other => {
                current.push(other);
                i += 1;
            }
        }
    }

    if !current.trim().is_empty() {
        out.push(current);
    }
    out
}
275
/// Turns a free-form migration name into a file-name slug.
///
/// Alphanumeric characters are kept and lowercased, including non-ASCII letters. Every
/// other character becomes `_`. No collapsing or trimming is done, so `"Add Users Table!"`
/// becomes `"add_users_table_"`.
fn slugify(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    for c in name.chars() {
        if c.is_alphanumeric() {
            // `to_lowercase` also lowercases non-ASCII letters (`Ä` -> `ä`), which
            // `to_ascii_lowercase` would have left uppercase.
            slug.extend(c.to_lowercase());
        } else {
            slug.push('_');
        }
    }
    slug
}
287
#[cfg(test)]
mod tests {
    use super::*;

    // A `;` inside a quoted literal must not end the statement.
    #[test]
    fn split_ignores_semicolon_in_string() {
        let statement_count = split_statements("INSERT INTO t VALUES ('a;b'); SELECT 1;").len();
        assert_eq!(statement_count, 2);
    }

    // A `;` inside a `--` comment must not end the statement.
    #[test]
    fn split_ignores_line_comments() {
        let statement_count = split_statements("SELECT 1; -- comment with ;\nSELECT 2;").len();
        assert_eq!(statement_count, 2);
    }

    // Uppercase ASCII is lowercased and non-alphanumeric characters become `_`.
    #[test]
    fn slugify_lowercases_and_replaces() {
        let slug = slugify("Add Users Table!");
        assert_eq!(slug, "add_users_table_");
    }
}