1use sqlx::{AnyPool, Row};
2
/// Error type returned by the schema and migration operations in this module
/// (an alias for `sqlx::Error`).
pub type DbErr = sqlx::Error;
4
5pub struct SchemaManager<'a> {
6 pub pool: &'a AnyPool,
7}
8
9impl<'a> SchemaManager<'a> {
10 pub fn new(pool: &'a AnyPool) -> Self {
11 Self { pool }
12 }
13}
14
15pub struct Schema;
16
17impl Schema {
18 pub async fn create<F>(manager: &SchemaManager<'_>, table_name: &str, callback: F) -> Result<(), DbErr>
19 where
20 F: FnOnce(&mut Blueprint),
21 {
22 let mut blueprint = Blueprint::new(table_name);
23 callback(&mut blueprint);
24
25 let sqls = blueprint.to_create_sqls(manager.pool).await;
26 for sql in sqls {
27 sqlx::query::<sqlx::Any>(&sql).execute(manager.pool).await?;
28 }
29 Ok(())
30 }
31
32 pub async fn table<F>(manager: &SchemaManager<'_>, table_name: &str, callback: F) -> Result<(), DbErr>
33 where
34 F: FnOnce(&mut Blueprint),
35 {
36 let mut blueprint = Blueprint::new(table_name);
37 blueprint.auto_id = false;
38 blueprint.timestamps = false;
39 callback(&mut blueprint);
40
41 let sqls = blueprint.to_alter_sqls(manager.pool).await;
42 for sql in sqls {
43 sqlx::query::<sqlx::Any>(&sql).execute(manager.pool).await?;
44 }
45 Ok(())
46 }
47
48 pub async fn drop(manager: &SchemaManager<'_>, table_name: &str) -> Result<(), DbErr> {
49 let sql = format!("DROP TABLE IF EXISTS `{}`", table_name);
50 sqlx::query(&sql).execute(manager.pool).await?;
51 Ok(())
52 }
53}
54
/// A single column declared on a `Blueprint`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name; emitted backtick-quoted in the generated SQL.
    pub name: String,
    /// Backend-neutral type name (e.g. `VARCHAR(255)`, `BIG_INCREMENTS`) that is
    /// mapped to a concrete SQL type when statements are generated.
    pub col_type: String,
    /// When `false`, the definition gets `NOT NULL`.
    pub nullable: bool,
    /// Requests a uniqueness constraint on the column.
    pub unique: bool,
    /// Requests `PRIMARY KEY` on the column.
    pub primary_key: bool,
    /// Raw SQL for the `DEFAULT` clause, inserted verbatim — string literals
    /// must carry their own quotes.
    pub default_val: Option<String>,
    /// Requests a secondary index on the column.
    pub is_indexed: bool,
}
64
/// A table-level `FOREIGN KEY (from_col) REFERENCES to_table (to_col)` constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKey {
    /// Referencing column in the table being built.
    pub from_col: String,
    /// Referenced column in `to_table`.
    pub to_col: String,
    /// Referenced table.
    pub to_table: String,
    /// Optional `ON DELETE` action (e.g. `CASCADE`), emitted verbatim.
    pub on_delete: Option<String>,
    /// Optional `ON UPDATE` action, emitted verbatim.
    pub on_update: Option<String>,
}
72
/// In-memory description of a table (columns, foreign keys, columns to drop)
/// that `Schema::create` / `Schema::table` render into SQL.
pub struct Blueprint {
    /// Table the generated statements target.
    pub table_name: String,
    /// Columns to create (or to add, when altering), in declaration order.
    pub columns: Vec<Column>,
    /// Table-level foreign key constraints emitted by `CREATE TABLE`.
    pub foreign_keys: Vec<ForeignKey>,
    /// Column names to drop when altering a table.
    pub drop_columns: Vec<String>,
    /// Whether `CREATE TABLE` emits an implicit auto-increment `id` primary key.
    pub auto_id: bool,
    /// Whether `CREATE TABLE` emits `created_at` / `updated_at` columns.
    pub timestamps: bool,
}
81
82impl Blueprint {
83 pub fn new(table_name: &str) -> Self {
84 Self {
85 table_name: table_name.to_string(),
86 columns: Vec::new(),
87 foreign_keys: Vec::new(),
88 drop_columns: Vec::new(),
89 auto_id: true,
90 timestamps: true,
91 }
92 }
93
94 pub fn no_id(&mut self) -> &mut Self {
95 self.auto_id = false;
96 self
97 }
98
99 pub fn no_timestamps(&mut self) -> &mut Self {
100 self.timestamps = false;
101 self
102 }
103
104 pub fn id(&mut self) -> &mut Self {
105 self.auto_id = true;
106 self
107 }
108
109 fn add_col(&mut self, name: &str, col_type: &str) -> &mut Column {
110 self.columns.push(Column {
111 name: name.to_string(),
112 col_type: col_type.to_string(),
113 nullable: false,
114 unique: false,
115 primary_key: false,
116 default_val: None,
117 is_indexed: false,
118 });
119 self.columns.last_mut().unwrap()
120 }
121
122 pub fn string(&mut self, name: &str) -> ColumnBuilder<'_> {
123 self.add_col(name, "VARCHAR(255)");
124 ColumnBuilder::new(self)
125 }
126
127
128 pub fn text(&mut self, name: &str) -> ColumnBuilder<'_> {
129 self.add_col(name, "TEXT");
130 ColumnBuilder::new(self)
131 }
132
133 pub fn long_text(&mut self, name: &str) -> ColumnBuilder<'_> {
134 self.add_col(name, "LONGTEXT");
135 ColumnBuilder::new(self)
136 }
137
138 pub fn medium_text(&mut self, name: &str) -> ColumnBuilder<'_> {
139 self.add_col(name, "MEDIUMTEXT");
140 ColumnBuilder::new(self)
141 }
142
143 pub fn tiny_text(&mut self, name: &str) -> ColumnBuilder<'_> {
144 self.add_col(name, "TINYTEXT");
145 ColumnBuilder::new(self)
146 }
147
148 pub fn integer(&mut self, name: &str) -> ColumnBuilder<'_> {
149 self.add_col(name, "INTEGER");
150 ColumnBuilder::new(self)
151 }
152
153 pub fn big_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
154 self.add_col(name, "BIGINT");
155 ColumnBuilder::new(self)
156 }
157
158 pub fn unsigned_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
159 self.add_col(name, "INT UNSIGNED");
160 ColumnBuilder::new(self)
161 }
162
163 pub fn unsigned_big_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
164 self.add_col(name, "BIGINT UNSIGNED");
165 ColumnBuilder::new(self)
166 }
167
168 pub fn unsigned_medium_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
169 self.add_col(name, "MEDIUMINT UNSIGNED");
170 ColumnBuilder::new(self)
171 }
172
173 pub fn unsigned_small_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
174 self.add_col(name, "SMALLINT UNSIGNED");
175 ColumnBuilder::new(self)
176 }
177
178 pub fn unsigned_tiny_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
179 self.add_col(name, "TINYINT UNSIGNED");
180 ColumnBuilder::new(self)
181 }
182
183 pub fn big_increments(&mut self, name: &str) -> ColumnBuilder<'_> {
184 self.add_col(name, "BIG_INCREMENTS");
185 self.columns.last_mut().unwrap().primary_key = true;
186 ColumnBuilder::new(self)
187 }
188
189 pub fn float(&mut self, name: &str) -> ColumnBuilder<'_> {
190 self.add_col(name, "FLOAT");
191 ColumnBuilder::new(self)
192 }
193
194 pub fn double(&mut self, name: &str) -> ColumnBuilder<'_> {
195 self.add_col(name, "DOUBLE");
196 ColumnBuilder::new(self)
197 }
198
199 pub fn decimal(&mut self, name: &str) -> ColumnBuilder<'_> {
200 self.add_col(name, "DECIMAL(10,2)");
201 ColumnBuilder::new(self)
202 }
203
204 pub fn char(&mut self, name: &str) -> ColumnBuilder<'_> {
205 self.add_col(name, "CHAR(255)");
206 ColumnBuilder::new(self)
207 }
208
209 pub fn boolean(&mut self, name: &str) -> ColumnBuilder<'_> {
210 self.add_col(name, "BOOLEAN");
211 ColumnBuilder::new(self)
212 }
213
214 pub fn date_time(&mut self, name: &str) -> ColumnBuilder<'_> {
215 self.add_col(name, "DATETIME");
216 ColumnBuilder::new(self)
217 }
218
219 pub fn date(&mut self, name: &str) -> ColumnBuilder<'_> {
220 self.add_col(name, "DATE");
221 ColumnBuilder::new(self)
222 }
223
224 pub fn time(&mut self, name: &str) -> ColumnBuilder<'_> {
225 self.add_col(name, "TIME");
226 ColumnBuilder::new(self)
227 }
228
229 pub fn timestamp(&mut self, name: &str) -> ColumnBuilder<'_> {
230 self.add_col(name, "TIMESTAMP");
231 ColumnBuilder::new(self)
232 }
233
234 pub fn soft_deletes(&mut self) -> ColumnBuilder<'_> {
235 self.add_col("deleted_at", "DATETIME");
236 self.columns.last_mut().unwrap().nullable = true;
237 ColumnBuilder::new(self)
238 }
239
240 pub fn uuid(&mut self, name: &str) -> ColumnBuilder<'_> {
241 self.add_col(name, "VARCHAR(36)");
242 ColumnBuilder::new(self)
243 }
244
245 pub fn json(&mut self, name: &str) -> ColumnBuilder<'_> {
246 self.add_col(name, "TEXT");
247 ColumnBuilder::new(self)
248 }
249
250 pub fn json_binary(&mut self, name: &str) -> ColumnBuilder<'_> {
251 self.add_col(name, "TEXT");
252 ColumnBuilder::new(self)
253 }
254
255 pub fn binary(&mut self, name: &str) -> ColumnBuilder<'_> {
256 self.add_col(name, "BLOB");
257 ColumnBuilder::new(self)
258 }
259
260 pub fn timestamps(&mut self) -> &mut Self {
261 self.timestamps = true;
262 self
263 }
264
265 pub fn foreign<'a>(&'a mut self, from_col: &str) -> ForeignKeyBuilder<'a> {
266 ForeignKeyBuilder::new(self, from_col)
267 }
268
269 pub fn drop_column(&mut self, name: &str) -> &mut Self {
270 self.drop_columns.push(name.to_string());
271 self
272 }
273
274 fn map_col_type(&self, col_type: &str, is_mysql: bool) -> String {
275 let mut mapped = col_type.to_string();
276 if mapped == "BIG_INCREMENTS" {
277 if is_mysql {
278 return "BIGINT AUTO_INCREMENT".to_string();
279 } else {
280 return "INTEGER".to_string();
281 }
282 }
283 if is_mysql {
284 if mapped == "DATETIME" || mapped == "TIMESTAMP" {
285 mapped = "VARCHAR(255)".to_string();
286 } else if mapped == "DATE" {
287 mapped = "VARCHAR(10)".to_string();
288 } else if mapped == "TIME" {
289 mapped = "VARCHAR(8)".to_string();
290 }
291 } else {
292 if mapped == "DATETIME" || mapped == "TIMESTAMP" || mapped == "DATE" || mapped == "TIME" {
293 mapped = "TEXT".to_string();
294 } else if mapped.contains("TEXT") {
295 mapped = "TEXT".to_string();
296 } else if mapped.contains("UNSIGNED") {
297 mapped = "INTEGER".to_string();
298 }
299 }
300 mapped
301 }
302
303 async fn to_alter_sqls(&self, pool: &AnyPool) -> Vec<String> {
304 let mut sqls = Vec::new();
305 let is_mysql = if let Ok(conn) = pool.acquire().await {
306 conn.backend_name() == "MySQL"
307 } else {
308 false
309 };
310
311 for col in &self.columns {
313 let col_type = self.map_col_type(&col.col_type, is_mysql);
314 let mut col_def = format!("`{}` {}", col.name, col_type);
315 if !col.nullable {
316 col_def.push_str(" NOT NULL");
317 }
318 if col.unique {
319 col_def.push_str(" UNIQUE");
320 }
321 if let Some(ref d) = col.default_val {
322 col_def.push_str(&format!(" DEFAULT {}", d));
323 }
324
325 let sql = format!("ALTER TABLE `{}` ADD COLUMN {}", self.table_name, col_def);
326 sqls.push(sql);
327 }
328
329 for col_name in &self.drop_columns {
331 let sql = format!("ALTER TABLE `{}` DROP COLUMN `{}`", self.table_name, col_name);
332 sqls.push(sql);
333 }
334
335 sqls
336 }
337
338 async fn to_create_sqls(&self, pool: &AnyPool) -> Vec<String> {
339 let mut sqls = Vec::new();
340 let is_mysql = if let Ok(conn) = pool.acquire().await {
341 conn.backend_name() == "MySQL"
342 } else {
343 false
344 };
345
346 let mut create_table = format!("CREATE TABLE IF NOT EXISTS `{}` (\n", self.table_name);
347 let mut col_parts = Vec::new();
348
349 if self.auto_id {
350 if is_mysql {
351 col_parts.push("`id` INT AUTO_INCREMENT PRIMARY KEY".to_string());
352 } else {
353 col_parts.push("`id` INTEGER PRIMARY KEY AUTOINCREMENT".to_string());
354 }
355 }
356
357 for col in &self.columns {
358 let col_type = self.map_col_type(&col.col_type, is_mysql);
359 let mut col_def = format!("`{}` {}", col.name, col_type);
360 if col.primary_key && !self.auto_id {
361 col_def.push_str(" PRIMARY KEY");
362 }
363 if !col.nullable {
364 col_def.push_str(" NOT NULL");
365 }
366 if col.unique {
367 col_def.push_str(" UNIQUE");
368 }
369 if let Some(ref d) = col.default_val {
370 col_def.push_str(&format!(" DEFAULT {}", d));
371 }
372 col_parts.push(col_def);
373 }
374
375 if self.timestamps {
376 if is_mysql {
377 col_parts.push("`created_at` VARCHAR(255) NOT NULL DEFAULT ''".to_string());
378 col_parts.push("`updated_at` VARCHAR(255) NOT NULL DEFAULT ''".to_string());
379 } else {
380 col_parts.push("`created_at` TEXT DEFAULT CURRENT_TIMESTAMP".to_string());
381 col_parts.push("`updated_at` TEXT DEFAULT CURRENT_TIMESTAMP".to_string());
382 }
383 }
384
385 for fk in &self.foreign_keys {
386 let mut fk_def = format!(
387 "FOREIGN KEY (`{}`) REFERENCES `{}` (`{}`)",
388 fk.from_col, fk.to_table, fk.to_col
389 );
390 if let Some(ref del) = fk.on_delete {
391 fk_def.push_str(&format!(" ON DELETE {}", del));
392 }
393 if let Some(ref upd) = fk.on_update {
394 fk_def.push_str(&format!(" ON UPDATE {}", upd));
395 }
396 col_parts.push(fk_def);
397 }
398
399 create_table.push_str(&col_parts.join(",\n"));
400 create_table.push_str("\n)");
401 sqls.push(create_table);
402
403 for col in &self.columns {
405 if col.is_indexed {
406 sqls.push(format!(
407 "CREATE INDEX IF NOT EXISTS `{}_{}_idx` ON `{}` (`{}`)",
408 self.table_name, col.name, self.table_name, col.name
409 ));
410 }
411 }
412
413 sqls
414 }
415}
416
417pub struct ColumnBuilder<'a> {
418 blueprint: &'a mut Blueprint,
419}
420
421impl<'a> ColumnBuilder<'a> {
422 pub fn new(blueprint: &'a mut Blueprint) -> Self {
423 Self { blueprint }
424 }
425
426 fn current(&mut self) -> &mut Column {
427 self.blueprint.columns.last_mut().unwrap()
428 }
429
430 pub fn not_null(mut self) -> Self {
431 self.current().nullable = false;
432 self
433 }
434
435 pub fn nullable(mut self) -> Self {
436 self.current().nullable = true;
437 self
438 }
439
440 pub fn unique(mut self) -> Self {
441 self.current().unique = true;
442 self
443 }
444
445 pub fn primary_key(mut self) -> Self {
446 self.current().primary_key = true;
447 self
448 }
449
450 pub fn default<T: ToString>(mut self, value: T) -> Self {
451 self.current().default_val = Some(value.to_string());
452 self
453 }
454
455 pub fn index(mut self) -> Self {
456 self.current().is_indexed = true;
457 self
458 }
459}
460
461pub struct ForeignKeyBuilder<'a> {
462 blueprint: &'a mut Blueprint,
463 from_col: String,
464 to_col: Option<String>,
465 to_table: Option<String>,
466 on_delete: Option<String>,
467 on_update: Option<String>,
468}
469
470impl<'a> ForeignKeyBuilder<'a> {
471 pub fn new(blueprint: &'a mut Blueprint, from_col: &str) -> Self {
472 Self {
473 blueprint,
474 from_col: from_col.to_string(),
475 to_col: None,
476 to_table: None,
477 on_delete: None,
478 on_update: None,
479 }
480 }
481
482 pub fn references(mut self, to_col: &str) -> Self {
483 self.to_col = Some(to_col.to_string());
484 self
485 }
486
487 pub fn on(mut self, to_table: &str) -> Self {
488 self.to_table = Some(to_table.to_string());
489 self
490 }
491
492 pub fn on_delete(mut self, action: &str) -> Self {
493 self.on_delete = Some(action.to_uppercase());
494 self
495 }
496
497 pub fn on_update(mut self, action: &str) -> Self {
498 self.on_update = Some(action.to_uppercase());
499 self
500 }
501}
502
503impl<'a> Drop for ForeignKeyBuilder<'a> {
504 fn drop(&mut self) {
505 if let (Some(to_table), Some(to_col)) = (&self.to_table, &self.to_col) {
506 self.blueprint.foreign_keys.push(ForeignKey {
507 from_col: self.from_col.clone(),
508 to_col: to_col.clone(),
509 to_table: to_table.clone(),
510 on_delete: self.on_delete.clone(),
511 on_update: self.on_update.clone(),
512 });
513 }
514 }
515}
516
/// A single reversible schema migration.
///
/// Implementations must be `Send + Sync`; the migrator receives them as
/// `Box<dyn MigrationTrait>` from `MigratorTrait::migrations`.
#[crate::async_trait]
pub trait MigrationTrait: Send + Sync {
    /// Identifier stored in the `migration_history.version` column (its primary key)
    /// once the migration has been applied.
    fn name(&self) -> &str;
    /// Applies the schema change.
    async fn up<'a>(&self, manager: &'a SchemaManager<'a>) -> Result<(), DbErr>;
    /// Reverts the change made by `up`; invoked when rolling back.
    async fn down<'a>(&self, manager: &'a SchemaManager<'a>) -> Result<(), DbErr>;
}
527
528#[crate::async_trait]
529pub trait MigratorTrait {
530 fn migrations() -> Vec<Box<dyn MigrationTrait>>;
531
532 async fn up(pool: &AnyPool, _steps: Option<u32>) -> Result<(), DbErr> {
533 let manager = SchemaManager::new(pool);
534
535 sqlx::query::<sqlx::Any>("CREATE TABLE IF NOT EXISTS migration_history (
537 version VARCHAR(255) PRIMARY KEY,
538 applied_at BIGINT NOT NULL
539 )").execute(pool).await?;
540
541 let rows = sqlx::query::<sqlx::Any>("SELECT version FROM migration_history").fetch_all(pool).await?;
543 let applied: std::collections::HashSet<String> = rows.into_iter()
544 .map(|r| r.get::<String, _>("version"))
545 .collect();
546
547 for migration in Self::migrations() {
549 let name = migration.name();
550 if !applied.contains(name) {
551 migration.up(&manager).await?;
552 let now = chrono::Utc::now().timestamp();
553 sqlx::query::<sqlx::Any>("INSERT INTO migration_history (version, applied_at) VALUES (?, ?)")
554 .bind(name)
555 .bind(now)
556 .execute(pool)
557 .await?;
558 println!("â
Migration applied: {}", name);
559 }
560 }
561
562 Ok(())
563 }
564
565 async fn down(pool: &AnyPool, _steps: Option<u32>) -> Result<(), DbErr> {
566 let manager = SchemaManager::new(pool);
567
568 let row_opt = sqlx::query::<sqlx::Any>("SELECT version FROM migration_history ORDER BY applied_at DESC LIMIT 1")
570 .fetch_optional(pool)
571 .await?;
572
573 if let Some(row) = row_opt {
574 let version = row.get::<String, _>("version");
575 for migration in Self::migrations() {
576 if migration.name() == version {
577 migration.down(&manager).await?;
578 sqlx::query::<sqlx::Any>("DELETE FROM migration_history WHERE version = ?")
579 .bind(&version)
580 .execute(pool)
581 .await?;
582 println!("âŦ
ī¸ Rollback applied: {}", version);
583 break;
584 }
585 }
586 } else {
587 println!("âšī¸ No migrations found to rollback.");
588 }
589
590 Ok(())
591 }
592
593 async fn fresh(pool: &AnyPool) -> Result<(), DbErr> {
594 let manager = SchemaManager::new(pool);
595
596 let applied_rows = sqlx::query::<sqlx::Any>("SELECT version FROM migration_history ORDER BY applied_at DESC")
598 .fetch_all(pool)
599 .await
600 .unwrap_or_default();
601
602 let migrations = Self::migrations();
603 for row in applied_rows {
604 let version = row.get::<String, _>("version");
605 if let Some(migration) = migrations.iter().find(|m| m.name() == version) {
606 let _ = migration.down(&manager).await;
607 }
608 }
609
610 let _ = sqlx::query::<sqlx::Any>("DROP TABLE IF EXISTS migration_history").execute(pool).await;
612
613 Self::up(pool, None).await?;
615 Ok(())
616 }
617}