1use crate::sql::{self, AnyPool};
2
3pub type DbErr = crate::sql::Error;
4
5pub struct SchemaManager<'a> {
6 pub pool: &'a AnyPool,
7}
8
9impl<'a> SchemaManager<'a> {
10 pub fn new(pool: &'a AnyPool) -> Self {
11 Self { pool }
12 }
13}
14
/// Namespace for the table-level operations (`create`, `table`, `drop`).
///
/// The type carries no state; all operations are associated functions.
pub struct Schema;
16
17impl Schema {
18 pub async fn create<F>(manager: &SchemaManager<'_>, table_name: &str, callback: F) -> Result<(), DbErr>
19 where
20 F: FnOnce(&mut Blueprint),
21 {
22 let mut blueprint = Blueprint::new(table_name);
23 callback(&mut blueprint);
24
25 let sqls = blueprint.to_create_sqls(manager.pool).await;
26 for sql in sqls {
27 sql::query::<sql::Any>(&sql).execute(manager.pool).await?;
28 }
29 Ok(())
30 }
31
32 pub async fn table<F>(manager: &SchemaManager<'_>, table_name: &str, callback: F) -> Result<(), DbErr>
33 where
34 F: FnOnce(&mut Blueprint),
35 {
36 let mut blueprint = Blueprint::new(table_name);
37 blueprint.auto_id = false;
38 blueprint.timestamps = false;
39 callback(&mut blueprint);
40
41 let sqls = blueprint.to_alter_sqls(manager.pool).await;
42 for sql in sqls {
43 sql::query::<sql::Any>(&sql).execute(manager.pool).await?;
44 }
45 Ok(())
46 }
47
48 pub async fn drop(manager: &SchemaManager<'_>, table_name: &str) -> Result<(), DbErr> {
49 let sql = format!("DROP TABLE IF EXISTS `{}`", table_name);
50 sql::query(&sql).execute(manager.pool).await?;
51 Ok(())
52 }
53}
54
/// One column recorded on a [`Blueprint`].
///
/// `col_type` holds the portable type token chosen by the blueprint method
/// (for example `"VARCHAR(255)"` or `"BIG_INCREMENTS"`); it is translated to a
/// backend-specific type only when SQL is generated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name, emitted between backticks.
    pub name: String,
    /// Portable type token, mapped per backend at generation time.
    pub col_type: String,
    /// When `false`, `NOT NULL` is appended to the definition.
    pub nullable: bool,
    /// When `true`, `UNIQUE` is appended to the definition.
    pub unique: bool,
    /// Marks the column as primary key.
    pub primary_key: bool,
    /// Raw SQL text placed after `DEFAULT`; it is emitted verbatim, so string
    /// defaults must already carry their own quotes.
    pub default_val: Option<String>,
    /// When `true`, a single-column index is generated for this column.
    pub is_indexed: bool,
}
64
/// A foreign-key constraint recorded on a [`Blueprint`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKey {
    /// Referencing column in the table being built.
    pub from_col: String,
    /// Referenced column in `to_table`.
    pub to_col: String,
    /// Referenced table.
    pub to_table: String,
    /// Action emitted after `ON DELETE`, if any.
    pub on_delete: Option<String>,
    /// Action emitted after `ON UPDATE`, if any.
    pub on_update: Option<String>,
}
72
73pub struct Blueprint {
74 pub table_name: String,
75 pub columns: Vec<Column>,
76 pub foreign_keys: Vec<ForeignKey>,
77 pub drop_columns: Vec<String>,
78 pub auto_id: bool,
79 pub timestamps: bool,
80}
81
82impl Blueprint {
83 pub fn new(table_name: &str) -> Self {
84 Self {
85 table_name: table_name.to_string(),
86 columns: Vec::new(),
87 foreign_keys: Vec::new(),
88 drop_columns: Vec::new(),
89 auto_id: true,
90 timestamps: true,
91 }
92 }
93
94 pub fn no_id(&mut self) -> &mut Self {
95 self.auto_id = false;
96 self
97 }
98
99 pub fn no_timestamps(&mut self) -> &mut Self {
100 self.timestamps = false;
101 self
102 }
103
104 pub fn id(&mut self) -> &mut Self {
105 self.auto_id = true;
106 self
107 }
108
109 fn add_col(&mut self, name: &str, col_type: &str) -> &mut Column {
110 self.columns.push(Column {
111 name: name.to_string(),
112 col_type: col_type.to_string(),
113 nullable: false,
114 unique: false,
115 primary_key: false,
116 default_val: None,
117 is_indexed: false,
118 });
119 self.columns.last_mut().unwrap()
120 }
121
122 pub fn string(&mut self, name: &str) -> ColumnBuilder<'_> {
123 self.add_col(name, "VARCHAR(255)");
124 ColumnBuilder::new(self)
125 }
126
127
128 pub fn text(&mut self, name: &str) -> ColumnBuilder<'_> {
129 self.add_col(name, "TEXT");
130 ColumnBuilder::new(self)
131 }
132
133 pub fn long_text(&mut self, name: &str) -> ColumnBuilder<'_> {
134 self.add_col(name, "LONGTEXT");
135 ColumnBuilder::new(self)
136 }
137
138 pub fn medium_text(&mut self, name: &str) -> ColumnBuilder<'_> {
139 self.add_col(name, "MEDIUMTEXT");
140 ColumnBuilder::new(self)
141 }
142
143 pub fn tiny_text(&mut self, name: &str) -> ColumnBuilder<'_> {
144 self.add_col(name, "TINYTEXT");
145 ColumnBuilder::new(self)
146 }
147
148 pub fn integer(&mut self, name: &str) -> ColumnBuilder<'_> {
149 self.add_col(name, "INTEGER");
150 ColumnBuilder::new(self)
151 }
152
153 pub fn big_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
154 self.add_col(name, "BIGINT");
155 ColumnBuilder::new(self)
156 }
157
158 pub fn unsigned_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
159 self.add_col(name, "INT UNSIGNED");
160 ColumnBuilder::new(self)
161 }
162
163 pub fn unsigned_big_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
164 self.add_col(name, "BIGINT UNSIGNED");
165 ColumnBuilder::new(self)
166 }
167
168 pub fn unsigned_medium_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
169 self.add_col(name, "MEDIUMINT UNSIGNED");
170 ColumnBuilder::new(self)
171 }
172
173 pub fn unsigned_small_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
174 self.add_col(name, "SMALLINT UNSIGNED");
175 ColumnBuilder::new(self)
176 }
177
178 pub fn unsigned_tiny_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
179 self.add_col(name, "TINYINT UNSIGNED");
180 ColumnBuilder::new(self)
181 }
182
183 pub fn big_increments(&mut self, name: &str) -> ColumnBuilder<'_> {
184 self.add_col(name, "BIG_INCREMENTS");
185 self.columns.last_mut().unwrap().primary_key = true;
186 ColumnBuilder::new(self)
187 }
188
189 pub fn float(&mut self, name: &str) -> ColumnBuilder<'_> {
190 self.add_col(name, "FLOAT");
191 ColumnBuilder::new(self)
192 }
193
194 pub fn double(&mut self, name: &str) -> ColumnBuilder<'_> {
195 self.add_col(name, "DOUBLE");
196 ColumnBuilder::new(self)
197 }
198
199 pub fn decimal(&mut self, name: &str) -> ColumnBuilder<'_> {
200 self.add_col(name, "DECIMAL(10,2)");
201 ColumnBuilder::new(self)
202 }
203
204 pub fn char(&mut self, name: &str) -> ColumnBuilder<'_> {
205 self.add_col(name, "CHAR(255)");
206 ColumnBuilder::new(self)
207 }
208
209 pub fn boolean(&mut self, name: &str) -> ColumnBuilder<'_> {
210 self.add_col(name, "BOOLEAN");
211 ColumnBuilder::new(self)
212 }
213
214 pub fn date_time(&mut self, name: &str) -> ColumnBuilder<'_> {
215 self.add_col(name, "DATETIME");
216 ColumnBuilder::new(self)
217 }
218
219 pub fn date(&mut self, name: &str) -> ColumnBuilder<'_> {
220 self.add_col(name, "DATE");
221 ColumnBuilder::new(self)
222 }
223
224 pub fn time(&mut self, name: &str) -> ColumnBuilder<'_> {
225 self.add_col(name, "TIME");
226 ColumnBuilder::new(self)
227 }
228
229 pub fn timestamp(&mut self, name: &str) -> ColumnBuilder<'_> {
230 self.add_col(name, "TIMESTAMP");
231 ColumnBuilder::new(self)
232 }
233
234 pub fn soft_deletes(&mut self) -> ColumnBuilder<'_> {
235 self.add_col("deleted_at", "DATETIME");
236 self.columns.last_mut().unwrap().nullable = true;
237 ColumnBuilder::new(self)
238 }
239
240 pub fn uuid(&mut self, name: &str) -> ColumnBuilder<'_> {
241 self.add_col(name, "VARCHAR(36)");
242 ColumnBuilder::new(self)
243 }
244
245 pub fn json(&mut self, name: &str) -> ColumnBuilder<'_> {
246 self.add_col(name, "TEXT");
247 ColumnBuilder::new(self)
248 }
249
250 pub fn json_binary(&mut self, name: &str) -> ColumnBuilder<'_> {
251 self.add_col(name, "TEXT");
252 ColumnBuilder::new(self)
253 }
254
255 pub fn binary(&mut self, name: &str) -> ColumnBuilder<'_> {
256 self.add_col(name, "BLOB");
257 ColumnBuilder::new(self)
258 }
259
260 pub fn timestamps(&mut self) -> &mut Self {
261 self.timestamps = true;
262 self
263 }
264
265 pub fn foreign<'a>(&'a mut self, from_col: &str) -> ForeignKeyBuilder<'a> {
266 ForeignKeyBuilder::new(self, from_col)
267 }
268
269 pub fn drop_column(&mut self, name: &str) -> &mut Self {
270 self.drop_columns.push(name.to_string());
271 self
272 }
273
274 fn map_col_type(&self, col_type: &str, is_mysql: bool) -> String {
275 let mut mapped = col_type.to_string();
276 if mapped == "BIG_INCREMENTS" {
277 if is_mysql {
278 return "BIGINT AUTO_INCREMENT".to_string();
279 } else {
280 return "INTEGER".to_string();
281 }
282 }
283 if is_mysql {
284 if mapped == "DATETIME" || mapped == "TIMESTAMP" {
285 mapped = "VARCHAR(255)".to_string();
286 } else if mapped == "DATE" {
287 mapped = "VARCHAR(10)".to_string();
288 } else if mapped == "TIME" {
289 mapped = "VARCHAR(8)".to_string();
290 }
291 } else {
292 if mapped == "DATETIME" || mapped == "TIMESTAMP" || mapped == "DATE" || mapped == "TIME" || mapped.contains("TEXT") {
293 mapped = "TEXT".to_string();
294 } else if mapped.contains("UNSIGNED") {
295 mapped = "INTEGER".to_string();
296 }
297 }
298 mapped
299 }
300
301 async fn to_alter_sqls(&self, pool: &AnyPool) -> Vec<String> {
302 let mut sqls = Vec::new();
303 let is_mysql = if let Ok(conn) = pool.acquire().await {
304 conn.backend_name() == "MySQL"
305 } else {
306 false
307 };
308
309 for col in &self.columns {
311 let col_type = self.map_col_type(&col.col_type, is_mysql);
312 let mut col_def = format!("`{}` {}", col.name, col_type);
313 if !col.nullable {
314 col_def.push_str(" NOT NULL");
315 }
316 if col.unique {
317 col_def.push_str(" UNIQUE");
318 }
319 if let Some(ref d) = col.default_val {
320 col_def.push_str(&format!(" DEFAULT {}", d));
321 }
322
323 let sql = format!("ALTER TABLE `{}` ADD COLUMN {}", self.table_name, col_def);
324 sqls.push(sql);
325 }
326
327 for col_name in &self.drop_columns {
329 let sql = format!("ALTER TABLE `{}` DROP COLUMN `{}`", self.table_name, col_name);
330 sqls.push(sql);
331 }
332
333 sqls
334 }
335
336 async fn to_create_sqls(&self, pool: &AnyPool) -> Vec<String> {
337 let mut sqls = Vec::new();
338 let is_mysql = if let Ok(conn) = pool.acquire().await {
339 conn.backend_name() == "MySQL"
340 } else {
341 false
342 };
343
344 let mut create_table = format!("CREATE TABLE IF NOT EXISTS `{}` (\n", self.table_name);
345 let mut col_parts = Vec::new();
346
347 if self.auto_id {
348 if is_mysql {
349 col_parts.push("`id` INT AUTO_INCREMENT PRIMARY KEY".to_string());
350 } else {
351 col_parts.push("`id` INTEGER PRIMARY KEY AUTOINCREMENT".to_string());
352 }
353 }
354
355 for col in &self.columns {
356 let col_type = self.map_col_type(&col.col_type, is_mysql);
357 let mut col_def = format!("`{}` {}", col.name, col_type);
358 if col.primary_key && !self.auto_id {
359 col_def.push_str(" PRIMARY KEY");
360 }
361 if !col.nullable {
362 col_def.push_str(" NOT NULL");
363 }
364 if col.unique {
365 col_def.push_str(" UNIQUE");
366 }
367 if let Some(ref d) = col.default_val {
368 col_def.push_str(&format!(" DEFAULT {}", d));
369 }
370 col_parts.push(col_def);
371 }
372
373 if self.timestamps {
374 if is_mysql {
375 col_parts.push("`created_at` VARCHAR(255) NOT NULL DEFAULT ''".to_string());
376 col_parts.push("`updated_at` VARCHAR(255) NOT NULL DEFAULT ''".to_string());
377 } else {
378 col_parts.push("`created_at` TEXT DEFAULT CURRENT_TIMESTAMP".to_string());
379 col_parts.push("`updated_at` TEXT DEFAULT CURRENT_TIMESTAMP".to_string());
380 }
381 }
382
383 for fk in &self.foreign_keys {
384 let mut fk_def = format!(
385 "FOREIGN KEY (`{}`) REFERENCES `{}` (`{}`)",
386 fk.from_col, fk.to_table, fk.to_col
387 );
388 if let Some(ref del) = fk.on_delete {
389 fk_def.push_str(&format!(" ON DELETE {}", del));
390 }
391 if let Some(ref upd) = fk.on_update {
392 fk_def.push_str(&format!(" ON UPDATE {}", upd));
393 }
394 col_parts.push(fk_def);
395 }
396
397 create_table.push_str(&col_parts.join(",\n"));
398 create_table.push_str("\n)");
399 sqls.push(create_table);
400
401 for col in &self.columns {
403 if col.is_indexed {
404 sqls.push(format!(
405 "CREATE INDEX IF NOT EXISTS `{}_{}_idx` ON `{}` (`{}`)",
406 self.table_name, col.name, self.table_name, col.name
407 ));
408 }
409 }
410
411 sqls
412 }
413}
414
415pub struct ColumnBuilder<'a> {
416 blueprint: &'a mut Blueprint,
417}
418
419impl<'a> ColumnBuilder<'a> {
420 pub fn new(blueprint: &'a mut Blueprint) -> Self {
421 Self { blueprint }
422 }
423
424 fn current(&mut self) -> &mut Column {
425 self.blueprint.columns.last_mut().unwrap()
426 }
427
428 pub fn not_null(mut self) -> Self {
429 self.current().nullable = false;
430 self
431 }
432
433 pub fn nullable(mut self) -> Self {
434 self.current().nullable = true;
435 self
436 }
437
438 pub fn unique(mut self) -> Self {
439 self.current().unique = true;
440 self
441 }
442
443 pub fn primary_key(mut self) -> Self {
444 self.current().primary_key = true;
445 self
446 }
447
448 pub fn default<T: ToString>(mut self, value: T) -> Self {
449 self.current().default_val = Some(value.to_string());
450 self
451 }
452
453 pub fn index(mut self) -> Self {
454 self.current().is_indexed = true;
455 self
456 }
457}
458
459pub struct ForeignKeyBuilder<'a> {
460 blueprint: &'a mut Blueprint,
461 from_col: String,
462 to_col: Option<String>,
463 to_table: Option<String>,
464 on_delete: Option<String>,
465 on_update: Option<String>,
466}
467
468impl<'a> ForeignKeyBuilder<'a> {
469 pub fn new(blueprint: &'a mut Blueprint, from_col: &str) -> Self {
470 Self {
471 blueprint,
472 from_col: from_col.to_string(),
473 to_col: None,
474 to_table: None,
475 on_delete: None,
476 on_update: None,
477 }
478 }
479
480 pub fn references(mut self, to_col: &str) -> Self {
481 self.to_col = Some(to_col.to_string());
482 self
483 }
484
485 pub fn on(mut self, to_table: &str) -> Self {
486 self.to_table = Some(to_table.to_string());
487 self
488 }
489
490 pub fn on_delete(mut self, action: &str) -> Self {
491 self.on_delete = Some(action.to_uppercase());
492 self
493 }
494
495 pub fn on_update(mut self, action: &str) -> Self {
496 self.on_update = Some(action.to_uppercase());
497 self
498 }
499}
500
501impl<'a> Drop for ForeignKeyBuilder<'a> {
502 fn drop(&mut self) {
503 if let (Some(to_table), Some(to_col)) = (&self.to_table, &self.to_col) {
504 self.blueprint.foreign_keys.push(ForeignKey {
505 from_col: self.from_col.clone(),
506 to_col: to_col.clone(),
507 to_table: to_table.clone(),
508 on_delete: self.on_delete.clone(),
509 on_update: self.on_update.clone(),
510 });
511 }
512 }
513}
514
515#[crate::async_trait]
520pub trait MigrationTrait: Send + Sync {
521 fn name(&self) -> &str;
522 async fn up<'a>(&self, manager: &'a SchemaManager<'a>) -> Result<(), DbErr>;
523 async fn down<'a>(&self, manager: &'a SchemaManager<'a>) -> Result<(), DbErr>;
524}
525
526#[crate::async_trait]
527pub trait MigratorTrait {
528 fn migrations() -> Vec<Box<dyn MigrationTrait>>;
529
530 async fn up(pool: &AnyPool, _steps: Option<u32>) -> Result<(), DbErr> {
531 let manager = SchemaManager::new(pool);
532
533 sql::query::<sql::Any>("CREATE TABLE IF NOT EXISTS migration_history (
535 version VARCHAR(255) PRIMARY KEY,
536 applied_at BIGINT NOT NULL
537 )").execute(pool).await?;
538
539 let rows = sql::query::<sql::Any>("SELECT version FROM migration_history").fetch_all(pool).await?;
541 let applied: std::collections::HashSet<String> = rows.into_iter()
542 .map(|r| r.get::<String, _>("version"))
543 .collect();
544
545 for migration in Self::migrations() {
547 let name = migration.name();
548 if !applied.contains(name) {
549 migration.up(&manager).await?;
550 let now = crate::chrono::Utc::now().timestamp();
551 sql::query::<sql::Any>("INSERT INTO migration_history (version, applied_at) VALUES (?, ?)")
552 .bind(name)
553 .bind(now)
554 .execute(pool)
555 .await?;
556 println!("â
Migration applied: {}", name);
557 }
558 }
559
560 Ok(())
561 }
562
563 async fn down(pool: &AnyPool, _steps: Option<u32>) -> Result<(), DbErr> {
564 let manager = SchemaManager::new(pool);
565
566 let row_opt = sql::query::<sql::Any>("SELECT version FROM migration_history ORDER BY applied_at DESC LIMIT 1")
568 .fetch_optional(pool)
569 .await?;
570
571 if let Some(row) = row_opt {
572 let version = row.get::<String, _>("version");
573 for migration in Self::migrations() {
574 if migration.name() == version {
575 migration.down(&manager).await?;
576 sql::query::<sql::Any>("DELETE FROM migration_history WHERE version = ?")
577 .bind(&version)
578 .execute(pool)
579 .await?;
580 println!("âŦ
ī¸ Rollback applied: {}", version);
581 break;
582 }
583 }
584 } else {
585 println!("âšī¸ No migrations found to rollback.");
586 }
587
588 Ok(())
589 }
590
591 async fn fresh(pool: &AnyPool) -> Result<(), DbErr> {
592 let manager = SchemaManager::new(pool);
593
594 let applied_rows = sql::query::<sql::Any>("SELECT version FROM migration_history ORDER BY applied_at DESC")
596 .fetch_all(pool)
597 .await
598 .unwrap_or_default();
599
600 let migrations = Self::migrations();
601 for row in applied_rows {
602 let version = row.get::<String, _>("version");
603 if let Some(migration) = migrations.iter().find(|m| m.name() == version) {
604 let _ = migration.down(&manager).await;
605 }
606 }
607
608 let _ = sql::query::<sql::Any>("DROP TABLE IF EXISTS migration_history").execute(pool).await;
610
611 Self::up(pool, None).await?;
613 Ok(())
614 }
615}