1use sea_query::{Alias, ForeignKey, Index, Table, TableCreateStatement};
2
3use vespertide_core::{ColumnDef, ColumnType, ComplexColumnType, TableConstraint};
4
5use super::helpers::{
6 build_create_enum_type_sql, build_schema_statement, build_sea_column_def_with_table,
7 collect_sqlite_enum_check_clauses, to_sea_fk_action,
8};
9use super::types::{BuiltQuery, DatabaseBackend, RawSql};
10use crate::error::QueryError;
11
12pub(crate) fn build_create_table_for_backend(
13 backend: &DatabaseBackend,
14 table: &str,
15 columns: &[ColumnDef],
16 constraints: &[TableConstraint],
17) -> TableCreateStatement {
18 let mut stmt = Table::create().table(Alias::new(table)).to_owned();
19
20 let has_table_primary_key = constraints
21 .iter()
22 .any(|c| matches!(c, TableConstraint::PrimaryKey { .. }));
23
24 let auto_increment_columns: std::collections::HashSet<&str> = constraints
26 .iter()
27 .filter_map(|c| {
28 if let TableConstraint::PrimaryKey {
29 columns: pk_cols,
30 auto_increment: true,
31 } = c
32 {
33 Some(pk_cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())
34 } else {
35 None
36 }
37 })
38 .flatten()
39 .collect();
40
41 for column in columns {
43 let mut col = build_sea_column_def_with_table(backend, table, column);
44
45 if column.primary_key.is_some() && !has_table_primary_key {
47 col.primary_key();
48 }
49
50 if auto_increment_columns.contains(column.name.as_str())
55 && column.r#type.supports_auto_increment()
56 {
57 if matches!(backend, DatabaseBackend::Sqlite) {
60 col.primary_key();
61 }
62 col.auto_increment();
63 }
64
65 stmt = stmt.col(col).to_owned();
70 }
71
72 for constraint in constraints {
74 match constraint {
75 TableConstraint::PrimaryKey {
76 columns: pk_cols,
77 auto_increment,
78 } => {
79 if matches!(backend, DatabaseBackend::Sqlite)
83 && *auto_increment
84 && pk_cols.iter().all(|col_name| {
85 columns
86 .iter()
87 .find(|c| c.name == *col_name)
88 .is_some_and(|c| c.r#type.supports_auto_increment())
89 })
90 {
91 continue;
92 }
93 let mut pk_idx = Index::create();
95 for c in pk_cols {
96 pk_idx = pk_idx.col(Alias::new(c)).to_owned();
97 }
98 stmt = stmt.primary_key(&mut pk_idx).to_owned();
99 }
100 TableConstraint::Unique {
101 name,
102 columns: unique_cols,
103 } => {
104 if matches!(backend, DatabaseBackend::MySql) {
107 let index_name = super::helpers::build_unique_constraint_name(
109 table,
110 unique_cols,
111 name.as_deref(),
112 );
113 let mut idx = Index::create().name(&index_name).unique().to_owned();
114 for col in unique_cols {
115 idx = idx.col(Alias::new(col)).to_owned();
116 }
117 stmt = stmt.index(&mut idx).to_owned();
118 }
119 }
122 TableConstraint::ForeignKey {
123 name,
124 columns: fk_cols,
125 ref_table,
126 ref_columns,
127 on_delete,
128 on_update,
129 } => {
130 let fk_name =
132 super::helpers::build_foreign_key_name(table, fk_cols, name.as_deref());
133 let mut fk = ForeignKey::create().name(&fk_name).to_owned();
134 fk = fk.from_tbl(Alias::new(table)).to_owned();
135 for col in fk_cols {
136 fk = fk.from_col(Alias::new(col)).to_owned();
137 }
138 fk = fk.to_tbl(Alias::new(ref_table)).to_owned();
139 for col in ref_columns {
140 fk = fk.to_col(Alias::new(col)).to_owned();
141 }
142 if let Some(action) = on_delete {
143 fk = fk.on_delete(to_sea_fk_action(action)).to_owned();
144 }
145 if let Some(action) = on_update {
146 fk = fk.on_update(to_sea_fk_action(action)).to_owned();
147 }
148 stmt = stmt.foreign_key(&mut fk).to_owned();
149 }
150 TableConstraint::Check { name, expr } => {
151 let _ = (name, expr);
154 }
155 TableConstraint::Index { .. } => {
156 }
159 }
160 }
161
162 stmt
163}
164
165pub fn build_create_table(
166 backend: &DatabaseBackend,
167 table: &str,
168 columns: &[ColumnDef],
169 constraints: &[TableConstraint],
170) -> Result<Vec<BuiltQuery>, QueryError> {
171 let table_def = vespertide_core::TableDef {
174 description: None,
175 name: table.to_string(),
176 columns: columns.to_vec(),
177 constraints: constraints.to_vec(),
178 };
179 let normalized = table_def
180 .normalize()
181 .map_err(|e| QueryError::Other(format!("Failed to normalize table '{}': {}", table, e)))?;
182
183 let columns = &normalized.columns;
185 let constraints = &normalized.constraints;
186
187 let mut queries = Vec::new();
188
189 let mut created_enums = std::collections::HashSet::new();
192 for column in columns {
193 if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = &column.r#type
194 && created_enums.insert(name.clone())
195 && let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type)
196 {
197 queries.push(BuiltQuery::Raw(create_type_sql));
198 }
199 }
200
201 let (table_constraints, unique_constraints): (Vec<&TableConstraint>, Vec<&TableConstraint>) =
204 constraints
205 .iter()
206 .partition(|c| !matches!(c, TableConstraint::Unique { .. }));
207
208 let create_table_stmt = if matches!(backend, DatabaseBackend::MySql) {
212 build_create_table_for_backend(backend, table, columns, constraints)
213 } else {
214 let table_constraints_owned: Vec<TableConstraint> =
216 table_constraints.iter().cloned().cloned().collect();
217 build_create_table_for_backend(backend, table, columns, &table_constraints_owned)
218 };
219
220 if matches!(backend, DatabaseBackend::Sqlite) {
222 let enum_check_clauses = collect_sqlite_enum_check_clauses(table, columns);
223 if !enum_check_clauses.is_empty() {
224 let base_sql = build_schema_statement(&create_table_stmt, *backend);
226 let mut modified_sql = base_sql;
227 if let Some(pos) = modified_sql.rfind(')') {
228 let check_sql = enum_check_clauses.join(", ");
229 modified_sql.insert_str(pos, &format!(", {}", check_sql));
230 }
231 queries.push(BuiltQuery::Raw(RawSql::per_backend(
232 modified_sql.clone(),
233 modified_sql.clone(),
234 modified_sql,
235 )));
236 } else {
237 queries.push(BuiltQuery::CreateTable(Box::new(create_table_stmt)));
238 }
239 } else {
240 queries.push(BuiltQuery::CreateTable(Box::new(create_table_stmt)));
241 }
242
243 if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
245 for constraint in unique_constraints {
246 if let TableConstraint::Unique {
247 name,
248 columns: unique_cols,
249 } = constraint
250 {
251 let index_name = super::helpers::build_unique_constraint_name(
253 table,
254 unique_cols,
255 name.as_deref(),
256 );
257 let mut idx = Index::create()
258 .table(Alias::new(table))
259 .name(&index_name)
260 .unique()
261 .to_owned();
262 for col in unique_cols {
263 idx = idx.col(Alias::new(col)).to_owned();
264 }
265 queries.push(BuiltQuery::CreateIndex(Box::new(idx)));
266 }
267 }
268 }
269
270 for constraint in constraints {
272 if let TableConstraint::Index {
273 name,
274 columns: index_cols,
275 } = constraint
276 {
277 let index_name = super::helpers::build_index_name(table, index_cols, name.as_deref());
279 let mut idx = Index::create()
280 .table(Alias::new(table))
281 .name(&index_name)
282 .to_owned();
283 for col in index_cols {
284 idx = idx.col(Alias::new(col)).to_owned();
285 }
286 queries.push(BuiltQuery::CreateIndex(Box::new(idx)));
287 }
288 }
289
290 Ok(queries)
291}
292
// Tests for `build_create_table` across the Postgres, MySQL and SQLite backends.
//
// Each test renders every returned `BuiltQuery` with `q.build(backend)` and joins
// the results into one string. Some tests then check backend-specific substrings.
// Every test records the full SQL in an `insta` snapshot whose suffix names the
// case and backend.
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnType, EnumValues, SimpleColumnType};

    /// Builds a nullable column `name` of type `ty` with no default, comment,
    /// primary key, unique, index or foreign key.
    fn col(name: &str, ty: ColumnType) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable: true,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    // A single integer column with no constraints: each backend's CREATE TABLE
    // must contain the expected quoted table/column and integer type name.
    #[rstest]
    #[case::create_table_postgres(
        "create_table_postgres",
        DatabaseBackend::Postgres,
        &["CREATE TABLE \"users\" ( \"id\" integer )"]
    )]
    #[case::create_table_mysql(
        "create_table_mysql",
        DatabaseBackend::MySql,
        &["CREATE TABLE `users` ( `id` int )"]
    )]
    #[case::create_table_sqlite(
        "create_table_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users\" ( \"id\" integer )"]
    )]
    fn test_create_table(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        let result = build_create_table(
            &backend,
            "users",
            &[col("id", ColumnType::Simple(SimpleColumnType::Integer))],
            &[],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("create_table_{}", title) }, {
            assert_snapshot!(sql);
        });
    }

    // A column-level `unique: true` (no table-level constraint) is expected to
    // survive normalization and show up as "UNIQUE" in the generated SQL.
    #[rstest]
    #[case::inline_unique_postgres(DatabaseBackend::Postgres)]
    #[case::inline_unique_mysql(DatabaseBackend::MySql)]
    #[case::inline_unique_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_inline_unique(#[case] backend: DatabaseBackend) {
        use vespertide_core::schema::str_or_bool::StrOrBoolOrArray;

        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text));
        email_col.unique = Some(StrOrBoolOrArray::Bool(true));

        let result = build_create_table(
            &backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                email_col,
            ],
            &[],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        // NOTE(review): the first operand is redundant; the uppercase check
        // alone already covers it.
        assert!(
            sql.contains("UNIQUE") || sql.to_uppercase().contains("UNIQUE"),
            "Normalized unique constraint should be in SQL, but not found: {}",
            sql
        );
        with_settings!({ snapshot_suffix => format!("create_table_with_inline_unique_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A named table-level unique constraint: MySQL must mention UNIQUE in its
    // output, while Postgres/SQLite must emit a separate CREATE UNIQUE INDEX.
    #[rstest]
    #[case::table_level_unique_postgres(DatabaseBackend::Postgres)]
    #[case::table_level_unique_mysql(DatabaseBackend::MySql)]
    #[case::table_level_unique_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_table_level_unique(#[case] backend: DatabaseBackend) {
        let result = build_create_table(
            &backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                col("email", ColumnType::Simple(SimpleColumnType::Text)),
            ],
            &[TableConstraint::Unique {
                name: Some("uq_email".into()),
                columns: vec!["email".into()],
            }],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");
        assert!(sql.contains("CREATE TABLE"));
        match backend {
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("UNIQUE"),
                    "MySQL should have UNIQUE in CREATE TABLE: {}",
                    sql
                );
            }
            _ => {
                assert!(
                    sql.contains("CREATE UNIQUE INDEX"),
                    "Postgres/SQLite should have CREATE UNIQUE INDEX: {}",
                    sql
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_table_level_unique_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // Same as above, but the unique constraint has no name, so the index name
    // comes from `build_unique_constraint_name`'s fallback.
    #[rstest]
    #[case::table_level_unique_no_name_postgres(DatabaseBackend::Postgres)]
    #[case::table_level_unique_no_name_mysql(DatabaseBackend::MySql)]
    #[case::table_level_unique_no_name_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_table_level_unique_no_name(#[case] backend: DatabaseBackend) {
        let result = build_create_table(
            &backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                col("email", ColumnType::Simple(SimpleColumnType::Text)),
            ],
            &[TableConstraint::Unique {
                name: None,
                columns: vec!["email".into()],
            }],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");
        assert!(sql.contains("CREATE TABLE"));
        match backend {
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("UNIQUE"),
                    "MySQL should have UNIQUE in CREATE TABLE: {}",
                    sql
                );
            }
            _ => {
                assert!(
                    sql.contains("CREATE UNIQUE INDEX"),
                    "Postgres/SQLite should have CREATE UNIQUE INDEX: {}",
                    sql
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_table_level_unique_no_name_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A string-enum column with a default plus a non-auto-increment primary
    // key: only checks that building succeeds, then snapshots the
    // ";\n"-joined statements.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_enum_column(#[case] backend: DatabaseBackend) {
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "status".into(),
                r#type: ColumnType::Complex(ComplexColumnType::Enum {
                    name: "user_status".into(),
                    values: EnumValues::String(vec![
                        "active".into(),
                        "inactive".into(),
                        "pending".into(),
                    ]),
                }),
                nullable: false,
                default: Some("'active'".into()),
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];
        let constraints = vec![TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }];

        let result = build_create_table(&backend, "users", &columns, &constraints);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        with_settings!({ snapshot_suffix => format!("create_table_with_enum_column_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A table-level auto-increment primary key on an integer column: each
    // backend must render its own auto-increment keyword (case-insensitive
    // check via two spellings).
    #[rstest]
    #[case::auto_increment_postgres(DatabaseBackend::Postgres)]
    #[case::auto_increment_mysql(DatabaseBackend::MySql)]
    #[case::auto_increment_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_auto_increment_primary_key(#[case] backend: DatabaseBackend) {
        let columns = vec![ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }];
        let constraints = vec![TableConstraint::PrimaryKey {
            auto_increment: true,
            columns: vec!["id".into()],
        }];

        let result = build_create_table(&backend, "users", &columns, &constraints);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        match backend {
            DatabaseBackend::Postgres => {
                assert!(
                    sql.contains("SERIAL") || sql.contains("serial"),
                    "PostgreSQL should use SERIAL for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("AUTO_INCREMENT") || sql.contains("auto_increment"),
                    "MySQL should use AUTO_INCREMENT for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::Sqlite => {
                assert!(
                    sql.contains("AUTOINCREMENT") || sql.contains("autoincrement"),
                    "SQLite should use AUTOINCREMENT for auto_increment, got: {}",
                    sql
                );
            }
        }

        with_settings!({ snapshot_suffix => format!("create_table_with_auto_increment_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // Same expectations as the previous test, but the auto-increment primary
    // key is declared on the column itself (`PrimaryKeySyntax::Object`) and
    // must be picked up through normalization.
    #[rstest]
    #[case::inline_auto_increment_postgres(DatabaseBackend::Postgres)]
    #[case::inline_auto_increment_mysql(DatabaseBackend::MySql)]
    #[case::inline_auto_increment_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_inline_auto_increment_primary_key(#[case] backend: DatabaseBackend) {
        use vespertide_core::schema::primary_key::{PrimaryKeyDef, PrimaryKeySyntax};

        let columns = vec![ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: Some(PrimaryKeySyntax::Object(PrimaryKeyDef {
                auto_increment: true,
            })),
            unique: None,
            index: None,
            foreign_key: None,
        }];

        let result = build_create_table(&backend, "users", &columns, &[]);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        match backend {
            DatabaseBackend::Postgres => {
                assert!(
                    sql.contains("SERIAL") || sql.contains("serial"),
                    "PostgreSQL should use SERIAL for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("AUTO_INCREMENT") || sql.contains("auto_increment"),
                    "MySQL should use AUTO_INCREMENT for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::Sqlite => {
                assert!(
                    sql.contains("AUTOINCREMENT") || sql.contains("autoincrement"),
                    "SQLite should use AUTOINCREMENT for auto_increment, got: {}",
                    sql
                );
            }
        }

        with_settings!({ snapshot_suffix => format!("create_table_with_inline_auto_increment_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A timestamptz column defaulting to `NOW()`. On SQLite the output must not
    // contain `NOW()` and must contain `CURRENT_TIMESTAMP`; the other backends
    // are only snapshotted.
    #[rstest]
    #[case::timestamp_now_default_postgres(DatabaseBackend::Postgres)]
    #[case::timestamp_now_default_mysql(DatabaseBackend::MySql)]
    #[case::timestamp_now_default_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_timestamp_now_default(#[case] backend: DatabaseBackend) {
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::BigInt),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "created_at".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Timestamptz),
                nullable: false,
                default: Some("NOW()".into()),
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];

        let result = build_create_table(&backend, "events", &columns, &[]);
        assert!(result.is_ok(), "build_create_table failed: {:?}", result);
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(
                !sql.contains("NOW()"),
                "SQLite should not contain NOW(), got: {}",
                sql
            );
            assert!(
                sql.contains("CURRENT_TIMESTAMP"),
                "SQLite should use CURRENT_TIMESTAMP, got: {}",
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("create_table_with_timestamp_now_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}