1use sea_query::{Alias, ForeignKey, Index, Table, TableCreateStatement};
2
3use vespertide_core::{ColumnDef, ColumnType, ComplexColumnType, TableConstraint};
4
5use super::helpers::{
6 build_create_enum_type_sql, build_schema_statement, build_sea_column_def_with_table,
7 collect_sqlite_enum_check_clauses, to_sea_fk_action,
8};
9use super::types::{BuiltQuery, DatabaseBackend, RawSql};
10use crate::error::QueryError;
11
12pub(crate) fn build_create_table_for_backend(
13 backend: &DatabaseBackend,
14 table: &str,
15 columns: &[ColumnDef],
16 constraints: &[TableConstraint],
17) -> TableCreateStatement {
18 let mut stmt = Table::create().table(Alias::new(table)).to_owned();
19
20 let has_table_primary_key = constraints
21 .iter()
22 .any(|c| matches!(c, TableConstraint::PrimaryKey { .. }));
23
24 let auto_increment_columns: std::collections::HashSet<&str> = constraints
26 .iter()
27 .filter_map(|c| {
28 if let TableConstraint::PrimaryKey {
29 columns: pk_cols,
30 auto_increment: true,
31 } = c
32 {
33 Some(pk_cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())
34 } else {
35 None
36 }
37 })
38 .flatten()
39 .collect();
40
41 for column in columns {
43 let mut col = build_sea_column_def_with_table(backend, table, column);
44
45 if column.primary_key.is_some() && !has_table_primary_key {
47 col.primary_key();
48 }
49
50 if auto_increment_columns.contains(column.name.as_str()) {
52 if matches!(backend, DatabaseBackend::Sqlite) {
55 col.primary_key();
56 }
57 col.auto_increment();
58 }
59
60 stmt = stmt.col(col).to_owned();
65 }
66
67 for constraint in constraints {
69 match constraint {
70 TableConstraint::PrimaryKey {
71 columns: pk_cols,
72 auto_increment,
73 } => {
74 if matches!(backend, DatabaseBackend::Sqlite) && *auto_increment {
77 continue;
78 }
79 let mut pk_idx = Index::create();
81 for c in pk_cols {
82 pk_idx = pk_idx.col(Alias::new(c)).to_owned();
83 }
84 stmt = stmt.primary_key(&mut pk_idx).to_owned();
85 }
86 TableConstraint::Unique {
87 name,
88 columns: unique_cols,
89 } => {
90 if matches!(backend, DatabaseBackend::MySql) {
93 let index_name = super::helpers::build_unique_constraint_name(
95 table,
96 unique_cols,
97 name.as_deref(),
98 );
99 let mut idx = Index::create().name(&index_name).unique().to_owned();
100 for col in unique_cols {
101 idx = idx.col(Alias::new(col)).to_owned();
102 }
103 stmt = stmt.index(&mut idx).to_owned();
104 }
105 }
108 TableConstraint::ForeignKey {
109 name,
110 columns: fk_cols,
111 ref_table,
112 ref_columns,
113 on_delete,
114 on_update,
115 } => {
116 let fk_name =
118 super::helpers::build_foreign_key_name(table, fk_cols, name.as_deref());
119 let mut fk = ForeignKey::create().name(&fk_name).to_owned();
120 fk = fk.from_tbl(Alias::new(table)).to_owned();
121 for col in fk_cols {
122 fk = fk.from_col(Alias::new(col)).to_owned();
123 }
124 fk = fk.to_tbl(Alias::new(ref_table)).to_owned();
125 for col in ref_columns {
126 fk = fk.to_col(Alias::new(col)).to_owned();
127 }
128 if let Some(action) = on_delete {
129 fk = fk.on_delete(to_sea_fk_action(action)).to_owned();
130 }
131 if let Some(action) = on_update {
132 fk = fk.on_update(to_sea_fk_action(action)).to_owned();
133 }
134 stmt = stmt.foreign_key(&mut fk).to_owned();
135 }
136 TableConstraint::Check { name, expr } => {
137 let _ = (name, expr);
140 }
141 TableConstraint::Index { .. } => {
142 }
145 }
146 }
147
148 stmt
149}
150
151pub fn build_create_table(
152 backend: &DatabaseBackend,
153 table: &str,
154 columns: &[ColumnDef],
155 constraints: &[TableConstraint],
156) -> Result<Vec<BuiltQuery>, QueryError> {
157 let table_def = vespertide_core::TableDef {
160 description: None,
161 name: table.to_string(),
162 columns: columns.to_vec(),
163 constraints: constraints.to_vec(),
164 };
165 let normalized = table_def
166 .normalize()
167 .map_err(|e| QueryError::Other(format!("Failed to normalize table '{}': {}", table, e)))?;
168
169 let columns = &normalized.columns;
171 let constraints = &normalized.constraints;
172
173 let mut queries = Vec::new();
174
175 let mut created_enums = std::collections::HashSet::new();
178 for column in columns {
179 if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = &column.r#type
180 && created_enums.insert(name.clone())
181 && let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type)
182 {
183 queries.push(BuiltQuery::Raw(create_type_sql));
184 }
185 }
186
187 let (table_constraints, unique_constraints): (Vec<&TableConstraint>, Vec<&TableConstraint>) =
190 constraints
191 .iter()
192 .partition(|c| !matches!(c, TableConstraint::Unique { .. }));
193
194 let create_table_stmt = if matches!(backend, DatabaseBackend::MySql) {
198 build_create_table_for_backend(backend, table, columns, constraints)
199 } else {
200 let table_constraints_owned: Vec<TableConstraint> =
202 table_constraints.iter().cloned().cloned().collect();
203 build_create_table_for_backend(backend, table, columns, &table_constraints_owned)
204 };
205
206 if matches!(backend, DatabaseBackend::Sqlite) {
208 let enum_check_clauses = collect_sqlite_enum_check_clauses(table, columns);
209 if !enum_check_clauses.is_empty() {
210 let base_sql = build_schema_statement(&create_table_stmt, *backend);
212 let mut modified_sql = base_sql;
213 if let Some(pos) = modified_sql.rfind(')') {
214 let check_sql = enum_check_clauses.join(", ");
215 modified_sql.insert_str(pos, &format!(", {}", check_sql));
216 }
217 queries.push(BuiltQuery::Raw(RawSql::per_backend(
218 modified_sql.clone(),
219 modified_sql.clone(),
220 modified_sql,
221 )));
222 } else {
223 queries.push(BuiltQuery::CreateTable(Box::new(create_table_stmt)));
224 }
225 } else {
226 queries.push(BuiltQuery::CreateTable(Box::new(create_table_stmt)));
227 }
228
229 if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
231 for constraint in unique_constraints {
232 if let TableConstraint::Unique {
233 name,
234 columns: unique_cols,
235 } = constraint
236 {
237 let index_name = super::helpers::build_unique_constraint_name(
239 table,
240 unique_cols,
241 name.as_deref(),
242 );
243 let mut idx = Index::create()
244 .table(Alias::new(table))
245 .name(&index_name)
246 .unique()
247 .to_owned();
248 for col in unique_cols {
249 idx = idx.col(Alias::new(col)).to_owned();
250 }
251 queries.push(BuiltQuery::CreateIndex(Box::new(idx)));
252 }
253 }
254 }
255
256 for constraint in constraints {
258 if let TableConstraint::Index {
259 name,
260 columns: index_cols,
261 } = constraint
262 {
263 let index_name = super::helpers::build_index_name(table, index_cols, name.as_deref());
265 let mut idx = Index::create()
266 .table(Alias::new(table))
267 .name(&index_name)
268 .to_owned();
269 for col in index_cols {
270 idx = idx.col(Alias::new(col)).to_owned();
271 }
272 queries.push(BuiltQuery::CreateIndex(Box::new(idx)));
273 }
274 }
275
276 Ok(queries)
277}
278
#[cfg(test)]
mod tests {
    //! Tests for `build_create_table`. SQL is rendered per backend and checked with
    //! substring assertions plus insta snapshots. Snapshot names depend on the test
    //! function names and the `snapshot_suffix` values below.

    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnType, EnumValues, SimpleColumnType};

    /// Builds a nullable column of type `ty` with no default, comment, or inline constraints.
    fn col(name: &str, ty: ColumnType) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable: true,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    // A single integer column with no constraints renders as a bare CREATE TABLE.
    #[rstest]
    #[case::create_table_postgres(
        "create_table_postgres",
        DatabaseBackend::Postgres,
        &["CREATE TABLE \"users\" ( \"id\" integer )"]
    )]
    #[case::create_table_mysql(
        "create_table_mysql",
        DatabaseBackend::MySql,
        &["CREATE TABLE `users` ( `id` int )"]
    )]
    #[case::create_table_sqlite(
        "create_table_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users\" ( \"id\" integer )"]
    )]
    fn test_create_table(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        let result = build_create_table(
            &backend,
            "users",
            &[col("id", ColumnType::Simple(SimpleColumnType::Integer))],
            &[],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        // `title` already starts with "create_table_", so the suffix repeats it. Changing
        // it would rename the stored snapshots.
        with_settings!({ snapshot_suffix => format!("create_table_{}", title) }, {
            assert_snapshot!(sql);
        });
    }

    // A column-level `unique` marker (no table constraints given) must still produce a
    // UNIQUE clause or index once the table is normalized.
    #[rstest]
    #[case::inline_unique_postgres(DatabaseBackend::Postgres)]
    #[case::inline_unique_mysql(DatabaseBackend::MySql)]
    #[case::inline_unique_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_inline_unique(#[case] backend: DatabaseBackend) {
        use vespertide_core::schema::str_or_bool::StrOrBoolOrArray;

        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text));
        email_col.unique = Some(StrOrBoolOrArray::Bool(true));

        let result = build_create_table(
            &backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                email_col,
            ],
            &[],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        // NOTE(review): the first operand is subsumed by the case-insensitive second one.
        assert!(
            sql.contains("UNIQUE") || sql.to_uppercase().contains("UNIQUE"),
            "Normalized unique constraint should be in SQL, but not found: {}",
            sql
        );
        with_settings!({ snapshot_suffix => format!("create_table_with_inline_unique_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A named table-level Unique constraint: inline in CREATE TABLE on MySQL, and a
    // separate CREATE UNIQUE INDEX elsewhere.
    #[rstest]
    #[case::table_level_unique_postgres(DatabaseBackend::Postgres)]
    #[case::table_level_unique_mysql(DatabaseBackend::MySql)]
    #[case::table_level_unique_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_table_level_unique(#[case] backend: DatabaseBackend) {
        let result = build_create_table(
            &backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                col("email", ColumnType::Simple(SimpleColumnType::Text)),
            ],
            &[TableConstraint::Unique {
                name: Some("uq_email".into()),
                columns: vec!["email".into()],
            }],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");
        assert!(sql.contains("CREATE TABLE"));
        match backend {
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("UNIQUE"),
                    "MySQL should have UNIQUE in CREATE TABLE: {}",
                    sql
                );
            }
            _ => {
                assert!(
                    sql.contains("CREATE UNIQUE INDEX"),
                    "Postgres/SQLite should have CREATE UNIQUE INDEX: {}",
                    sql
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_table_level_unique_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // Same as above, but the constraint has no name, so the index name is generated.
    #[rstest]
    #[case::table_level_unique_no_name_postgres(DatabaseBackend::Postgres)]
    #[case::table_level_unique_no_name_mysql(DatabaseBackend::MySql)]
    #[case::table_level_unique_no_name_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_table_level_unique_no_name(#[case] backend: DatabaseBackend) {
        let result = build_create_table(
            &backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                col("email", ColumnType::Simple(SimpleColumnType::Text)),
            ],
            &[TableConstraint::Unique {
                name: None,
                columns: vec!["email".into()],
            }],
        )
        .unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");
        assert!(sql.contains("CREATE TABLE"));
        match backend {
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("UNIQUE"),
                    "MySQL should have UNIQUE in CREATE TABLE: {}",
                    sql
                );
            }
            _ => {
                assert!(
                    sql.contains("CREATE UNIQUE INDEX"),
                    "Postgres/SQLite should have CREATE UNIQUE INDEX: {}",
                    sql
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_table_level_unique_no_name_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // An enum column with a default. The output covers the enum type statement, if any, and
    // SQLite's CHECK-clause path. It is verified by snapshot only.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_enum_column(#[case] backend: DatabaseBackend) {
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "status".into(),
                r#type: ColumnType::Complex(ComplexColumnType::Enum {
                    name: "user_status".into(),
                    values: EnumValues::String(vec![
                        "active".into(),
                        "inactive".into(),
                        "pending".into(),
                    ]),
                }),
                nullable: false,
                default: Some("'active'".into()),
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];
        let constraints = vec![TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }];

        let result = build_create_table(&backend, "users", &columns, &constraints);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        with_settings!({ snapshot_suffix => format!("create_table_with_enum_column_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A table-level auto-increment primary key must render with each backend's
    // auto-increment keyword.
    #[rstest]
    #[case::auto_increment_postgres(DatabaseBackend::Postgres)]
    #[case::auto_increment_mysql(DatabaseBackend::MySql)]
    #[case::auto_increment_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_auto_increment_primary_key(#[case] backend: DatabaseBackend) {
        let columns = vec![ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }];
        let constraints = vec![TableConstraint::PrimaryKey {
            auto_increment: true,
            columns: vec!["id".into()],
        }];

        let result = build_create_table(&backend, "users", &columns, &constraints);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        match backend {
            DatabaseBackend::Postgres => {
                assert!(
                    sql.contains("SERIAL") || sql.contains("serial"),
                    "PostgreSQL should use SERIAL for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("AUTO_INCREMENT") || sql.contains("auto_increment"),
                    "MySQL should use AUTO_INCREMENT for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::Sqlite => {
                assert!(
                    sql.contains("AUTOINCREMENT") || sql.contains("autoincrement"),
                    "SQLite should use AUTOINCREMENT for auto_increment, got: {}",
                    sql
                );
            }
        }

        with_settings!({ snapshot_suffix => format!("create_table_with_auto_increment_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // The same expectations when auto-increment is declared inline on the column
    // (`primary_key: { auto_increment: true }`) rather than as a table constraint.
    #[rstest]
    #[case::inline_auto_increment_postgres(DatabaseBackend::Postgres)]
    #[case::inline_auto_increment_mysql(DatabaseBackend::MySql)]
    #[case::inline_auto_increment_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_inline_auto_increment_primary_key(#[case] backend: DatabaseBackend) {
        use vespertide_core::schema::primary_key::{PrimaryKeyDef, PrimaryKeySyntax};

        let columns = vec![ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: Some(PrimaryKeySyntax::Object(PrimaryKeyDef {
                auto_increment: true,
            })),
            unique: None,
            index: None,
            foreign_key: None,
        }];

        let result = build_create_table(&backend, "users", &columns, &[]);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        match backend {
            DatabaseBackend::Postgres => {
                assert!(
                    sql.contains("SERIAL") || sql.contains("serial"),
                    "PostgreSQL should use SERIAL for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("AUTO_INCREMENT") || sql.contains("auto_increment"),
                    "MySQL should use AUTO_INCREMENT for auto_increment, got: {}",
                    sql
                );
            }
            DatabaseBackend::Sqlite => {
                assert!(
                    sql.contains("AUTOINCREMENT") || sql.contains("autoincrement"),
                    "SQLite should use AUTOINCREMENT for auto_increment, got: {}",
                    sql
                );
            }
        }

        with_settings!({ snapshot_suffix => format!("create_table_with_inline_auto_increment_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A `NOW()` default on a timestamptz column. SQLite output must use CURRENT_TIMESTAMP
    // instead of NOW(); other backends are checked by snapshot only.
    #[rstest]
    #[case::timestamp_now_default_postgres(DatabaseBackend::Postgres)]
    #[case::timestamp_now_default_mysql(DatabaseBackend::MySql)]
    #[case::timestamp_now_default_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_timestamp_now_default(#[case] backend: DatabaseBackend) {
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::BigInt),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "created_at".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Timestamptz),
                nullable: false,
                default: Some("NOW()".into()),
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];

        let result = build_create_table(&backend, "events", &columns, &[]);
        assert!(result.is_ok(), "build_create_table failed: {:?}", result);
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(
                !sql.contains("NOW()"),
                "SQLite should not contain NOW(), got: {}",
                sql
            );
            assert!(
                sql.contains("CURRENT_TIMESTAMP"),
                "SQLite should use CURRENT_TIMESTAMP, got: {}",
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("create_table_with_timestamp_now_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}