1use crate::introspect::{
8 ColumnInfo, DatabaseSchema, Dialect, ForeignKeyInfo, IndexInfo, ParsedSqlType, TableInfo,
9 UniqueConstraintInfo,
10};
11use sqlmodel_core::{FieldInfo, Model};
12
13pub trait ModelSchema: Model {
42 fn table_schema() -> TableInfo {
44 table_schema_from_model::<Self>()
45 }
46}
47
48impl<M: Model> ModelSchema for M {}
50
51pub fn table_schema_from_model<M: Model>() -> TableInfo {
57 table_schema_from_fields(M::TABLE_NAME, M::fields(), M::PRIMARY_KEY)
58}
59
60pub fn table_schema_from_fields(
66 table_name: &str,
67 fields: &[FieldInfo],
68 primary_key_cols: &[&str],
69) -> TableInfo {
70 let mut columns = Vec::with_capacity(fields.len());
71 let mut foreign_keys = Vec::new();
72 let mut unique_constraints = Vec::new();
73 let mut indexes = Vec::new();
74
75 for field in fields {
76 let sql_type = field.effective_sql_type();
78 columns.push(ColumnInfo {
79 name: field.column_name.to_string(),
80 sql_type: sql_type.clone(),
81 parsed_type: ParsedSqlType::parse(&sql_type),
82 nullable: field.nullable,
83 default: field.default.map(String::from),
84 primary_key: field.primary_key,
85 auto_increment: field.auto_increment,
86 comment: None,
87 });
88
89 if let Some(fk_ref) = field.foreign_key {
91 if let Some((ref_table, ref_col)) = parse_fk_reference(fk_ref) {
92 foreign_keys.push(ForeignKeyInfo {
93 name: Some(format!("fk_{}_{}", table_name, field.column_name)),
94 column: field.column_name.to_string(),
95 foreign_table: ref_table,
96 foreign_column: ref_col,
97 on_delete: field.on_delete.map(|a| a.as_sql().to_string()),
98 on_update: field.on_update.map(|a| a.as_sql().to_string()),
99 });
100 }
101 }
102
103 if field.unique && !field.primary_key {
105 unique_constraints.push(UniqueConstraintInfo {
106 name: Some(format!("uk_{}_{}", table_name, field.column_name)),
107 columns: vec![field.column_name.to_string()],
108 });
109 }
110
111 if let Some(idx_name) = field.index {
113 indexes.push(IndexInfo {
114 name: idx_name.to_string(),
115 columns: vec![field.column_name.to_string()],
116 unique: false,
117 index_type: None,
118 primary: false,
119 });
120 }
121 }
122
123 TableInfo {
124 name: table_name.to_string(),
125 columns,
126 primary_key: primary_key_cols.iter().map(|s| s.to_string()).collect(),
127 foreign_keys,
128 unique_constraints,
129 check_constraints: Vec::new(),
130 indexes,
131 comment: None,
132 }
133}
134
/// Splits a `"table.column"` foreign-key reference into `(table, column)`.
///
/// Returns `None` unless the reference contains exactly one `.` with non-empty text on both
/// sides. This rejects bare names (`"invalid"`) and extra segments (`"too.many.parts"`). It
/// also rejects empty components (`".id"`, `"users."`, `"."`), which would otherwise
/// produce a foreign key pointing at an unnamed table or column.
fn parse_fk_reference(reference: &str) -> Option<(String, String)> {
    let (table, column) = reference.split_once('.')?;
    if table.is_empty() || column.is_empty() || column.contains('.') {
        return None;
    }
    Some((table.to_string(), column.to_string()))
}
144
145pub fn expected_schema<M: Model>(dialect: Dialect) -> DatabaseSchema {
157 let mut schema = DatabaseSchema::new(dialect);
158 let table_info = table_schema_from_model::<M>();
159 schema.tables.insert(table_info.name.clone(), table_info);
160 schema
161}
162
163pub trait ModelTuple {
167 fn all_table_schemas() -> Vec<TableInfo>;
169
170 fn database_schema(dialect: Dialect) -> DatabaseSchema {
172 let mut schema = DatabaseSchema::new(dialect);
173 for table in Self::all_table_schemas() {
174 schema.tables.insert(table.name.clone(), table);
175 }
176 schema
177 }
178}
179
180impl<A: Model> ModelTuple for (A,) {
182 fn all_table_schemas() -> Vec<TableInfo> {
183 vec![table_schema_from_model::<A>()]
184 }
185}
186
187impl<A: Model, B: Model> ModelTuple for (A, B) {
189 fn all_table_schemas() -> Vec<TableInfo> {
190 vec![
191 table_schema_from_model::<A>(),
192 table_schema_from_model::<B>(),
193 ]
194 }
195}
196
197impl<A: Model, B: Model, C: Model> ModelTuple for (A, B, C) {
199 fn all_table_schemas() -> Vec<TableInfo> {
200 vec![
201 table_schema_from_model::<A>(),
202 table_schema_from_model::<B>(),
203 table_schema_from_model::<C>(),
204 ]
205 }
206}
207
208impl<A: Model, B: Model, C: Model, D: Model> ModelTuple for (A, B, C, D) {
210 fn all_table_schemas() -> Vec<TableInfo> {
211 vec![
212 table_schema_from_model::<A>(),
213 table_schema_from_model::<B>(),
214 table_schema_from_model::<C>(),
215 table_schema_from_model::<D>(),
216 ]
217 }
218}
219
220impl<A: Model, B: Model, C: Model, D: Model, E: Model> ModelTuple for (A, B, C, D, E) {
222 fn all_table_schemas() -> Vec<TableInfo> {
223 vec![
224 table_schema_from_model::<A>(),
225 table_schema_from_model::<B>(),
226 table_schema_from_model::<C>(),
227 table_schema_from_model::<D>(),
228 table_schema_from_model::<E>(),
229 ]
230 }
231}
232
233impl<A: Model, B: Model, C: Model, D: Model, E: Model, F: Model> ModelTuple for (A, B, C, D, E, F) {
235 fn all_table_schemas() -> Vec<TableInfo> {
236 vec![
237 table_schema_from_model::<A>(),
238 table_schema_from_model::<B>(),
239 table_schema_from_model::<C>(),
240 table_schema_from_model::<D>(),
241 table_schema_from_model::<E>(),
242 table_schema_from_model::<F>(),
243 ]
244 }
245}
246
247impl<A: Model, B: Model, C: Model, D: Model, E: Model, F: Model, G: Model> ModelTuple
249 for (A, B, C, D, E, F, G)
250{
251 fn all_table_schemas() -> Vec<TableInfo> {
252 vec![
253 table_schema_from_model::<A>(),
254 table_schema_from_model::<B>(),
255 table_schema_from_model::<C>(),
256 table_schema_from_model::<D>(),
257 table_schema_from_model::<E>(),
258 table_schema_from_model::<F>(),
259 table_schema_from_model::<G>(),
260 ]
261 }
262}
263
264impl<A: Model, B: Model, C: Model, D: Model, E: Model, F: Model, G: Model, H: Model> ModelTuple
266 for (A, B, C, D, E, F, G, H)
267{
268 fn all_table_schemas() -> Vec<TableInfo> {
269 vec![
270 table_schema_from_model::<A>(),
271 table_schema_from_model::<B>(),
272 table_schema_from_model::<C>(),
273 table_schema_from_model::<D>(),
274 table_schema_from_model::<E>(),
275 table_schema_from_model::<F>(),
276 table_schema_from_model::<G>(),
277 table_schema_from_model::<H>(),
278 ]
279 }
280}
281
282pub fn normalize_sql_type(sql_type: &str, dialect: Dialect) -> String {
291 let upper = sql_type.to_uppercase();
292
293 match dialect {
294 Dialect::Sqlite => {
295 if upper.contains("INT") {
297 "INTEGER".to_string()
298 } else if upper.contains("CHAR") || upper.contains("TEXT") || upper.contains("CLOB") {
299 "TEXT".to_string()
300 } else if upper.contains("REAL") || upper.contains("FLOAT") || upper.contains("DOUB") {
301 "REAL".to_string()
302 } else if upper.contains("BLOB") || upper.is_empty() {
303 "BLOB".to_string()
304 } else {
305 upper
307 }
308 }
309 Dialect::Postgres => {
310 match upper.as_str() {
312 "INT" | "INT4" => "INTEGER".to_string(),
313 "INT8" => "BIGINT".to_string(),
314 "INT2" => "SMALLINT".to_string(),
315 "FLOAT4" => "REAL".to_string(),
316 "FLOAT8" => "DOUBLE PRECISION".to_string(),
317 "BOOL" => "BOOLEAN".to_string(),
318 "SERIAL" => "INTEGER".to_string(), "BIGSERIAL" => "BIGINT".to_string(),
320 "SMALLSERIAL" => "SMALLINT".to_string(),
321 _ => upper,
322 }
323 }
324 Dialect::Mysql => {
325 match upper.as_str() {
327 "INTEGER" => "INT".to_string(),
328 "BOOL" | "BOOLEAN" => "TINYINT".to_string(),
329 _ => upper,
330 }
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{ReferentialAction, Row, SqlType, Value};

    // Fixture model covering each schema feature derived from `FieldInfo`:
    // - an auto-increment primary key (`id`)
    // - a unique column with a SQL type override (`name`)
    // - an indexed column (`age`)
    // - a foreign key with an ON DELETE action (`team_id` -> `teams.id`)
    struct TestHero;

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt)
                    .nullable(true)
                    .primary_key(true)
                    .auto_increment(true),
                // No `.nullable(true)` here; the column test below expects NOT NULL.
                FieldInfo::new("name", "name", SqlType::Text)
                    .sql_type_override("VARCHAR(100)")
                    .unique(true),
                FieldInfo::new("age", "age", SqlType::Integer)
                    .nullable(true)
                    .index("idx_heroes_age"),
                FieldInfo::new("team_id", "team_id", SqlType::BigInt)
                    .nullable(true)
                    .foreign_key("teams.id")
                    .on_delete(ReferentialAction::Cascade),
            ];
            FIELDS
        }

        // The remaining `Model` methods are stubs: these tests only read static metadata.
        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![]
        }

        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(TestHero)
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![]
        }

        fn is_new(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_model_schema_table_name() {
        let schema = TestHero::table_schema();
        assert_eq!(schema.name, "heroes");
    }

    #[test]
    fn test_model_schema_columns() {
        let schema = TestHero::table_schema();
        assert_eq!(schema.columns.len(), 4);

        let id_col = schema.column("id").unwrap();
        assert_eq!(id_col.sql_type, "BIGINT");
        assert!(id_col.primary_key);
        assert!(id_col.auto_increment);

        // The explicit override takes precedence over the field's `SqlType::Text`.
        let name_col = schema.column("name").unwrap();
        assert_eq!(name_col.sql_type, "VARCHAR(100)");
        assert!(!name_col.nullable);
    }

    #[test]
    fn test_model_schema_primary_key() {
        let schema = TestHero::table_schema();
        assert_eq!(schema.primary_key, vec!["id"]);
    }

    #[test]
    fn test_model_schema_foreign_keys() {
        let schema = TestHero::table_schema();
        assert_eq!(schema.foreign_keys.len(), 1);

        // "teams.id" is split into table and column; the ON DELETE action is rendered as SQL.
        let fk = &schema.foreign_keys[0];
        assert_eq!(fk.column, "team_id");
        assert_eq!(fk.foreign_table, "teams");
        assert_eq!(fk.foreign_column, "id");
        assert_eq!(fk.on_delete, Some("CASCADE".to_string()));
    }

    #[test]
    fn test_model_schema_unique_constraints() {
        let schema = TestHero::table_schema();
        // Only `name` is declared unique in the fixture.
        assert_eq!(schema.unique_constraints.len(), 1);

        let uk = &schema.unique_constraints[0];
        assert_eq!(uk.columns, vec!["name"]);
    }

    #[test]
    fn test_model_schema_indexes() {
        let schema = TestHero::table_schema();
        assert_eq!(schema.indexes.len(), 1);

        let idx = &schema.indexes[0];
        assert_eq!(idx.name, "idx_heroes_age");
        assert_eq!(idx.columns, vec!["age"]);
        assert!(!idx.unique);
    }

    #[test]
    fn test_expected_schema() {
        let schema = expected_schema::<TestHero>(Dialect::Sqlite);
        assert_eq!(schema.dialect, Dialect::Sqlite);
        assert!(schema.table("heroes").is_some());
    }

    #[test]
    fn test_model_tuple_two() {
        // Second model, local to this test, so the tuple contains two distinct tables.
        struct TestTeam;

        impl Model for TestTeam {
            const TABLE_NAME: &'static str = "teams";
            const PRIMARY_KEY: &'static [&'static str] = &["id"];

            fn fields() -> &'static [FieldInfo] {
                static FIELDS: &[FieldInfo] = &[FieldInfo::new("id", "id", SqlType::BigInt)
                    .nullable(true)
                    .primary_key(true)
                    .auto_increment(true)];
                FIELDS
            }

            fn to_row(&self) -> Vec<(&'static str, Value)> {
                vec![]
            }

            fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
                Ok(TestTeam)
            }

            fn primary_key_value(&self) -> Vec<Value> {
                vec![]
            }

            fn is_new(&self) -> bool {
                true
            }
        }

        let schema = <(TestHero, TestTeam)>::database_schema(Dialect::Postgres);
        assert_eq!(schema.tables.len(), 2);
        assert!(schema.table("heroes").is_some());
        assert!(schema.table("teams").is_some());
    }

    #[test]
    fn test_normalize_sql_type_sqlite() {
        assert_eq!(normalize_sql_type("INTEGER", Dialect::Sqlite), "INTEGER");
        assert_eq!(normalize_sql_type("INT", Dialect::Sqlite), "INTEGER");
        assert_eq!(normalize_sql_type("BIGINT", Dialect::Sqlite), "INTEGER");
        assert_eq!(normalize_sql_type("VARCHAR(100)", Dialect::Sqlite), "TEXT");
        assert_eq!(normalize_sql_type("TEXT", Dialect::Sqlite), "TEXT");
        assert_eq!(normalize_sql_type("REAL", Dialect::Sqlite), "REAL");
        assert_eq!(normalize_sql_type("FLOAT", Dialect::Sqlite), "REAL");
    }

    #[test]
    fn test_normalize_sql_type_postgres() {
        assert_eq!(normalize_sql_type("INT", Dialect::Postgres), "INTEGER");
        assert_eq!(normalize_sql_type("INT4", Dialect::Postgres), "INTEGER");
        assert_eq!(normalize_sql_type("INT8", Dialect::Postgres), "BIGINT");
        assert_eq!(
            normalize_sql_type("FLOAT8", Dialect::Postgres),
            "DOUBLE PRECISION"
        );
        assert_eq!(normalize_sql_type("BOOL", Dialect::Postgres), "BOOLEAN");
        assert_eq!(normalize_sql_type("SERIAL", Dialect::Postgres), "INTEGER");
    }

    #[test]
    fn test_normalize_sql_type_mysql() {
        assert_eq!(normalize_sql_type("INTEGER", Dialect::Mysql), "INT");
        assert_eq!(normalize_sql_type("BOOLEAN", Dialect::Mysql), "TINYINT");
        assert_eq!(normalize_sql_type("BOOL", Dialect::Mysql), "TINYINT");
    }

    #[test]
    fn test_parse_fk_reference() {
        assert_eq!(
            parse_fk_reference("users.id"),
            Some(("users".to_string(), "id".to_string()))
        );
        assert_eq!(
            parse_fk_reference("teams.team_id"),
            Some(("teams".to_string(), "team_id".to_string()))
        );
        // A reference must have exactly one `.` separating table and column.
        assert_eq!(parse_fk_reference("invalid"), None);
        assert_eq!(parse_fk_reference("too.many.parts"), None);
    }
}