1use sqlx::{AnyPool, Row};
2
3pub type DbErr = sqlx::Error;
4
5pub struct SchemaManager<'a> {
6 pub pool: &'a AnyPool,
7}
8
9impl<'a> SchemaManager<'a> {
10 pub fn new(pool: &'a AnyPool) -> Self {
11 Self { pool }
12 }
13}
14
15pub struct Schema;
16
17impl Schema {
18 pub async fn create<F>(manager: &SchemaManager<'_>, table_name: &str, callback: F) -> Result<(), DbErr>
19 where
20 F: FnOnce(&mut Blueprint),
21 {
22 let mut blueprint = Blueprint::new(table_name);
23 callback(&mut blueprint);
24
25 let sqls = blueprint.to_create_sqls(manager.pool).await;
26 for sql in sqls {
27 sqlx::query::<sqlx::Any>(&sql).execute(manager.pool).await?;
28 }
29 Ok(())
30 }
31
32 pub async fn table<F>(manager: &SchemaManager<'_>, table_name: &str, callback: F) -> Result<(), DbErr>
33 where
34 F: FnOnce(&mut Blueprint),
35 {
36 let mut blueprint = Blueprint::new(table_name);
37 blueprint.auto_id = false;
38 blueprint.timestamps = false;
39 callback(&mut blueprint);
40
41 let sqls = blueprint.to_alter_sqls(manager.pool).await;
42 for sql in sqls {
43 sqlx::query::<sqlx::Any>(&sql).execute(manager.pool).await?;
44 }
45 Ok(())
46 }
47
48 pub async fn drop(manager: &SchemaManager<'_>, table_name: &str) -> Result<(), DbErr> {
49 let sql = format!("DROP TABLE IF EXISTS `{}`", table_name);
50 sqlx::query(&sql).execute(manager.pool).await?;
51 Ok(())
52 }
53}
54
/// Definition of a single table column as collected by a blueprint.
///
/// The derives make the type printable, copyable and comparable, which is
/// what callers need to inspect or assert on a collected definition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name; interpolated into SQL inside backticks.
    pub name: String,
    /// SQL type text, e.g. `VARCHAR(255)`.
    pub col_type: String,
    /// When `false`, the column is rendered with `NOT NULL`.
    pub nullable: bool,
    /// When `true`, the column is rendered with `UNIQUE`.
    pub unique: bool,
    /// Marks the column as a primary key.
    pub primary_key: bool,
    /// Raw SQL text placed after `DEFAULT`; it is emitted verbatim, so string
    /// defaults must already carry their own quotes.
    pub default_val: Option<String>,
    /// When `true`, a separate index is requested for this column.
    pub is_indexed: bool,
}
64
/// Definition of a foreign-key constraint as collected by a blueprint.
///
/// The derives make the type printable, copyable and comparable, which is
/// what callers need to inspect or assert on a collected definition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKey {
    /// Column in the table being defined that holds the reference.
    pub from_col: String,
    /// Referenced column in `to_table`.
    pub to_col: String,
    /// Referenced table.
    pub to_table: String,
    /// Action text emitted after `ON DELETE`, if any.
    pub on_delete: Option<String>,
    /// Action text emitted after `ON UPDATE`, if any.
    pub on_update: Option<String>,
}
72
73pub struct Blueprint {
74 pub table_name: String,
75 pub columns: Vec<Column>,
76 pub foreign_keys: Vec<ForeignKey>,
77 pub drop_columns: Vec<String>,
78 pub auto_id: bool,
79 pub timestamps: bool,
80}
81
82impl Blueprint {
83 pub fn new(table_name: &str) -> Self {
84 Self {
85 table_name: table_name.to_string(),
86 columns: Vec::new(),
87 foreign_keys: Vec::new(),
88 drop_columns: Vec::new(),
89 auto_id: true,
90 timestamps: true,
91 }
92 }
93
94 pub fn no_id(&mut self) -> &mut Self {
95 self.auto_id = false;
96 self
97 }
98
99 pub fn no_timestamps(&mut self) -> &mut Self {
100 self.timestamps = false;
101 self
102 }
103
104 pub fn id(&mut self) -> &mut Self {
105 self.auto_id = true;
106 self
107 }
108
109 fn add_col(&mut self, name: &str, col_type: &str) -> &mut Column {
110 self.columns.push(Column {
111 name: name.to_string(),
112 col_type: col_type.to_string(),
113 nullable: false,
114 unique: false,
115 primary_key: false,
116 default_val: None,
117 is_indexed: false,
118 });
119 self.columns.last_mut().unwrap()
120 }
121
122 pub fn string(&mut self, name: &str) -> ColumnBuilder<'_> {
123 self.add_col(name, "VARCHAR(255)");
124 ColumnBuilder::new(self)
125 }
126
127 pub fn text(&mut self, name: &str) -> ColumnBuilder<'_> {
128 self.add_col(name, "TEXT");
129 ColumnBuilder::new(self)
130 }
131
132 pub fn integer(&mut self, name: &str) -> ColumnBuilder<'_> {
133 self.add_col(name, "INTEGER");
134 ColumnBuilder::new(self)
135 }
136
137 pub fn big_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
138 self.add_col(name, "BIGINT");
139 ColumnBuilder::new(self)
140 }
141
142 pub fn float(&mut self, name: &str) -> ColumnBuilder<'_> {
143 self.add_col(name, "FLOAT");
144 ColumnBuilder::new(self)
145 }
146
147 pub fn double(&mut self, name: &str) -> ColumnBuilder<'_> {
148 self.add_col(name, "DOUBLE");
149 ColumnBuilder::new(self)
150 }
151
152 pub fn decimal(&mut self, name: &str) -> ColumnBuilder<'_> {
153 self.add_col(name, "DECIMAL(10,2)");
154 ColumnBuilder::new(self)
155 }
156
157 pub fn char(&mut self, name: &str) -> ColumnBuilder<'_> {
158 self.add_col(name, "CHAR(255)");
159 ColumnBuilder::new(self)
160 }
161
162 pub fn boolean(&mut self, name: &str) -> ColumnBuilder<'_> {
163 self.add_col(name, "BOOLEAN");
164 ColumnBuilder::new(self)
165 }
166
167 pub fn date_time(&mut self, name: &str) -> ColumnBuilder<'_> {
168 self.add_col(name, "DATETIME");
169 ColumnBuilder::new(self)
170 }
171
172 pub fn timestamp(&mut self, name: &str) -> ColumnBuilder<'_> {
173 self.add_col(name, "TIMESTAMP");
174 ColumnBuilder::new(self)
175 }
176
177 pub fn uuid(&mut self, name: &str) -> ColumnBuilder<'_> {
178 self.add_col(name, "VARCHAR(36)");
179 ColumnBuilder::new(self)
180 }
181
182 pub fn json(&mut self, name: &str) -> ColumnBuilder<'_> {
183 self.add_col(name, "TEXT");
184 ColumnBuilder::new(self)
185 }
186
187 pub fn json_binary(&mut self, name: &str) -> ColumnBuilder<'_> {
188 self.add_col(name, "TEXT");
189 ColumnBuilder::new(self)
190 }
191
192 pub fn binary(&mut self, name: &str) -> ColumnBuilder<'_> {
193 self.add_col(name, "BLOB");
194 ColumnBuilder::new(self)
195 }
196
197 pub fn timestamps(&mut self) -> &mut Self {
198 self.timestamps = true;
199 self
200 }
201
202 pub fn foreign<'a>(&'a mut self, from_col: &str) -> ForeignKeyBuilder<'a> {
203 ForeignKeyBuilder::new(self, from_col)
204 }
205
206 pub fn drop_column(&mut self, name: &str) -> &mut Self {
207 self.drop_columns.push(name.to_string());
208 self
209 }
210
211 async fn to_alter_sqls(&self, pool: &AnyPool) -> Vec<String> {
212 let mut sqls = Vec::new();
213 let is_mysql = if let Ok(conn) = pool.acquire().await {
214 conn.backend_name() == "MySQL"
215 } else {
216 false
217 };
218
219 for col in &self.columns {
221 let mut col_type = col.col_type.clone();
222 if !is_mysql && (col_type == "DATETIME" || col_type == "TIMESTAMP") {
223 col_type = "TEXT".to_string();
224 }
225 let mut col_def = format!("`{}` {}", col.name, col_type);
226 if !col.nullable {
227 col_def.push_str(" NOT NULL");
228 }
229 if col.unique {
230 col_def.push_str(" UNIQUE");
231 }
232 if let Some(ref d) = col.default_val {
233 col_def.push_str(&format!(" DEFAULT {}", d));
234 }
235
236 let sql = format!("ALTER TABLE `{}` ADD COLUMN {}", self.table_name, col_def);
237 sqls.push(sql);
238 }
239
240 for col_name in &self.drop_columns {
242 let sql = format!("ALTER TABLE `{}` DROP COLUMN `{}`", self.table_name, col_name);
243 sqls.push(sql);
244 }
245
246 sqls
247 }
248
249 async fn to_create_sqls(&self, pool: &AnyPool) -> Vec<String> {
250 let mut sqls = Vec::new();
251 let is_mysql = if let Ok(conn) = pool.acquire().await {
252 conn.backend_name() == "MySQL"
253 } else {
254 false
255 };
256
257 let mut create_table = format!("CREATE TABLE IF NOT EXISTS `{}` (\n", self.table_name);
258 let mut col_parts = Vec::new();
259
260 if self.auto_id {
261 if is_mysql {
262 col_parts.push("`id` INT AUTO_INCREMENT PRIMARY KEY".to_string());
263 } else {
264 col_parts.push("`id` INTEGER PRIMARY KEY AUTOINCREMENT".to_string());
265 }
266 }
267
268 for col in &self.columns {
269 let mut col_type = col.col_type.clone();
270 if !is_mysql && (col_type == "DATETIME" || col_type == "TIMESTAMP") {
271 col_type = "TEXT".to_string();
272 }
273 let mut col_def = format!("`{}` {}", col.name, col_type);
274 if col.primary_key && !self.auto_id {
275 col_def.push_str(" PRIMARY KEY");
276 }
277 if !col.nullable {
278 col_def.push_str(" NOT NULL");
279 }
280 if col.unique {
281 col_def.push_str(" UNIQUE");
282 }
283 if let Some(ref d) = col.default_val {
284 col_def.push_str(&format!(" DEFAULT {}", d));
285 }
286 col_parts.push(col_def);
287 }
288
289 if self.timestamps {
290 if is_mysql {
291 col_parts.push("`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP".to_string());
292 col_parts.push("`updated_at` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP".to_string());
293 } else {
294 col_parts.push("`created_at` TEXT DEFAULT CURRENT_TIMESTAMP".to_string());
295 col_parts.push("`updated_at` TEXT DEFAULT CURRENT_TIMESTAMP".to_string());
296 }
297 }
298
299 for fk in &self.foreign_keys {
300 let mut fk_def = format!(
301 "FOREIGN KEY (`{}`) REFERENCES `{}` (`{}`)",
302 fk.from_col, fk.to_table, fk.to_col
303 );
304 if let Some(ref del) = fk.on_delete {
305 fk_def.push_str(&format!(" ON DELETE {}", del));
306 }
307 if let Some(ref upd) = fk.on_update {
308 fk_def.push_str(&format!(" ON UPDATE {}", upd));
309 }
310 col_parts.push(fk_def);
311 }
312
313 create_table.push_str(&col_parts.join(",\n"));
314 create_table.push_str("\n)");
315 sqls.push(create_table);
316
317 for col in &self.columns {
319 if col.is_indexed {
320 sqls.push(format!(
321 "CREATE INDEX IF NOT EXISTS `{}_{}_idx` ON `{}` (`{}`)",
322 self.table_name, col.name, self.table_name, col.name
323 ));
324 }
325 }
326
327 sqls
328 }
329}
330
331pub struct ColumnBuilder<'a> {
332 blueprint: &'a mut Blueprint,
333}
334
335impl<'a> ColumnBuilder<'a> {
336 pub fn new(blueprint: &'a mut Blueprint) -> Self {
337 Self { blueprint }
338 }
339
340 fn current(&mut self) -> &mut Column {
341 self.blueprint.columns.last_mut().unwrap()
342 }
343
344 pub fn not_null(mut self) -> Self {
345 self.current().nullable = false;
346 self
347 }
348
349 pub fn nullable(mut self) -> Self {
350 self.current().nullable = true;
351 self
352 }
353
354 pub fn unique(mut self) -> Self {
355 self.current().unique = true;
356 self
357 }
358
359 pub fn primary_key(mut self) -> Self {
360 self.current().primary_key = true;
361 self
362 }
363
364 pub fn default<T: ToString>(mut self, value: T) -> Self {
365 self.current().default_val = Some(value.to_string());
366 self
367 }
368
369 pub fn index(mut self) -> Self {
370 self.current().is_indexed = true;
371 self
372 }
373}
374
375pub struct ForeignKeyBuilder<'a> {
376 blueprint: &'a mut Blueprint,
377 from_col: String,
378 to_col: Option<String>,
379 to_table: Option<String>,
380 on_delete: Option<String>,
381 on_update: Option<String>,
382}
383
384impl<'a> ForeignKeyBuilder<'a> {
385 pub fn new(blueprint: &'a mut Blueprint, from_col: &str) -> Self {
386 Self {
387 blueprint,
388 from_col: from_col.to_string(),
389 to_col: None,
390 to_table: None,
391 on_delete: None,
392 on_update: None,
393 }
394 }
395
396 pub fn references(mut self, to_col: &str) -> Self {
397 self.to_col = Some(to_col.to_string());
398 self
399 }
400
401 pub fn on(mut self, to_table: &str) -> Self {
402 self.to_table = Some(to_table.to_string());
403 self
404 }
405
406 pub fn on_delete(mut self, action: &str) -> Self {
407 self.on_delete = Some(action.to_uppercase());
408 self
409 }
410
411 pub fn on_update(mut self, action: &str) -> Self {
412 self.on_update = Some(action.to_uppercase());
413 self
414 }
415}
416
417impl<'a> Drop for ForeignKeyBuilder<'a> {
418 fn drop(&mut self) {
419 if let (Some(to_table), Some(to_col)) = (&self.to_table, &self.to_col) {
420 self.blueprint.foreign_keys.push(ForeignKey {
421 from_col: self.from_col.clone(),
422 to_col: to_col.clone(),
423 to_table: to_table.clone(),
424 on_delete: self.on_delete.clone(),
425 on_update: self.on_update.clone(),
426 });
427 }
428 }
429}
430
431#[crate::async_trait]
436pub trait MigrationTrait: Send + Sync {
437 fn name(&self) -> &str;
438 async fn up<'a>(&self, manager: &'a SchemaManager<'a>) -> Result<(), DbErr>;
439 async fn down<'a>(&self, manager: &'a SchemaManager<'a>) -> Result<(), DbErr>;
440}
441
442#[crate::async_trait]
443pub trait MigratorTrait {
444 fn migrations() -> Vec<Box<dyn MigrationTrait>>;
445
446 async fn up(pool: &AnyPool, _steps: Option<u32>) -> Result<(), DbErr> {
447 let manager = SchemaManager::new(pool);
448
449 sqlx::query::<sqlx::Any>("CREATE TABLE IF NOT EXISTS migration_history (
451 version VARCHAR(255) PRIMARY KEY,
452 applied_at BIGINT NOT NULL
453 )").execute(pool).await?;
454
455 let rows = sqlx::query::<sqlx::Any>("SELECT version FROM migration_history").fetch_all(pool).await?;
457 let applied: std::collections::HashSet<String> = rows.into_iter()
458 .map(|r| r.get::<String, _>("version"))
459 .collect();
460
461 for migration in Self::migrations() {
463 let name = migration.name();
464 if !applied.contains(name) {
465 migration.up(&manager).await?;
466 let now = chrono::Utc::now().timestamp();
467 sqlx::query::<sqlx::Any>("INSERT INTO migration_history (version, applied_at) VALUES (?, ?)")
468 .bind(name)
469 .bind(now)
470 .execute(pool)
471 .await?;
472 println!("â
Migration applied: {}", name);
473 }
474 }
475
476 Ok(())
477 }
478
479 async fn down(pool: &AnyPool, _steps: Option<u32>) -> Result<(), DbErr> {
480 let manager = SchemaManager::new(pool);
481
482 let row_opt = sqlx::query::<sqlx::Any>("SELECT version FROM migration_history ORDER BY applied_at DESC LIMIT 1")
484 .fetch_optional(pool)
485 .await?;
486
487 if let Some(row) = row_opt {
488 let version = row.get::<String, _>("version");
489 for migration in Self::migrations() {
490 if migration.name() == version {
491 migration.down(&manager).await?;
492 sqlx::query::<sqlx::Any>("DELETE FROM migration_history WHERE version = ?")
493 .bind(&version)
494 .execute(pool)
495 .await?;
496 println!("âŦ
ī¸ Rollback applied: {}", version);
497 break;
498 }
499 }
500 } else {
501 println!("âšī¸ No migrations found to rollback.");
502 }
503
504 Ok(())
505 }
506
507 async fn fresh(pool: &AnyPool) -> Result<(), DbErr> {
508 let manager = SchemaManager::new(pool);
509
510 let applied_rows = sqlx::query::<sqlx::Any>("SELECT version FROM migration_history ORDER BY applied_at DESC")
512 .fetch_all(pool)
513 .await
514 .unwrap_or_default();
515
516 let migrations = Self::migrations();
517 for row in applied_rows {
518 let version = row.get::<String, _>("version");
519 if let Some(migration) = migrations.iter().find(|m| m.name() == version) {
520 let _ = migration.down(&manager).await;
521 }
522 }
523
524 let _ = sqlx::query::<sqlx::Any>("DROP TABLE IF EXISTS migration_history").execute(pool).await;
526
527 Self::up(pool, None).await?;
529 Ok(())
530 }
531}