Skip to main content

sqlmodel_schema/
expected.rs

1//! Expected schema extraction from Model definitions.
2//!
3//! This module provides utilities to extract the "expected" database schema
4//! from Rust Model definitions, which can then be compared against the
5//! actual database schema obtained via introspection.
6
7use crate::introspect::{
8    ColumnInfo, DatabaseSchema, Dialect, ForeignKeyInfo, IndexInfo, ParsedSqlType, TableInfo,
9    UniqueConstraintInfo,
10};
11use sqlmodel_core::{FieldInfo, Model};
12
13// ============================================================================
14// Extension Trait for Model
15// ============================================================================
16
17/// Extension trait that adds schema extraction to Model types.
18///
19/// This trait is automatically implemented for all types that implement `Model`.
20/// It provides the `table_schema()` method to extract the expected table schema
21/// from the Model's metadata.
22///
23/// # Example
24///
25/// ```ignore
26/// use sqlmodel::Model;
27/// use sqlmodel_schema::expected::ModelSchema;
28///
29/// #[derive(Model)]
30/// #[sqlmodel(table = "heroes")]
31/// struct Hero {
32///     #[sqlmodel(primary_key, auto_increment)]
33///     id: Option<i64>,
34///     name: String,
35/// }
36///
37/// // Extract the expected schema
38/// let schema = Hero::table_schema();
39/// assert_eq!(schema.name, "heroes");
40/// ```
41pub trait ModelSchema: Model {
42    /// Get the expected table schema for this model.
43    fn table_schema() -> TableInfo {
44        table_schema_from_model::<Self>()
45    }
46}
47
48// Blanket implementation for all Model types
49impl<M: Model> ModelSchema for M {}
50
51// ============================================================================
52// Schema Extraction Functions
53// ============================================================================
54
55/// Extract a TableInfo from a Model type.
56pub fn table_schema_from_model<M: Model>() -> TableInfo {
57    table_schema_from_fields(M::TABLE_NAME, M::fields(), M::PRIMARY_KEY)
58}
59
60/// Convert field metadata to a TableInfo.
61///
62/// This is the core conversion function that transforms the compile-time
63/// FieldInfo array into a runtime TableInfo structure compatible with
64/// database introspection.
65pub fn table_schema_from_fields(
66    table_name: &str,
67    fields: &[FieldInfo],
68    primary_key_cols: &[&str],
69) -> TableInfo {
70    let mut columns = Vec::with_capacity(fields.len());
71    let mut foreign_keys = Vec::new();
72    let mut unique_constraints = Vec::new();
73    let mut indexes = Vec::new();
74
75    for field in fields {
76        // Convert FieldInfo to ColumnInfo
77        let sql_type = field.effective_sql_type();
78        columns.push(ColumnInfo {
79            name: field.column_name.to_string(),
80            sql_type: sql_type.clone(),
81            parsed_type: ParsedSqlType::parse(&sql_type),
82            nullable: field.nullable,
83            default: field.default.map(String::from),
84            primary_key: field.primary_key,
85            auto_increment: field.auto_increment,
86            comment: None,
87        });
88
89        // Extract foreign key if present
90        if let Some(fk_ref) = field.foreign_key {
91            if let Some((ref_table, ref_col)) = parse_fk_reference(fk_ref) {
92                foreign_keys.push(ForeignKeyInfo {
93                    name: Some(format!("fk_{}_{}", table_name, field.column_name)),
94                    column: field.column_name.to_string(),
95                    foreign_table: ref_table,
96                    foreign_column: ref_col,
97                    on_delete: field.on_delete.map(|a| a.as_sql().to_string()),
98                    on_update: field.on_update.map(|a| a.as_sql().to_string()),
99                });
100            }
101        }
102
103        // Extract unique constraint if present (and not part of PK)
104        if field.unique && !field.primary_key {
105            unique_constraints.push(UniqueConstraintInfo {
106                name: Some(format!("uk_{}_{}", table_name, field.column_name)),
107                columns: vec![field.column_name.to_string()],
108            });
109        }
110
111        // Extract index if present
112        if let Some(idx_name) = field.index {
113            indexes.push(IndexInfo {
114                name: idx_name.to_string(),
115                columns: vec![field.column_name.to_string()],
116                unique: false,
117                index_type: None,
118                primary: false,
119            });
120        }
121    }
122
123    TableInfo {
124        name: table_name.to_string(),
125        columns,
126        primary_key: primary_key_cols.iter().map(|s| s.to_string()).collect(),
127        foreign_keys,
128        unique_constraints,
129        check_constraints: Vec::new(),
130        indexes,
131        comment: None,
132    }
133}
134
/// Parse a foreign key reference string (e.g. `"users.id"`) into
/// `(table, column)`.
///
/// Whitespace around either part is ignored. Returns `None` unless the
/// reference has exactly one `.` with a non-empty name on each side. For
/// example, `"invalid"`, `"schema.table.col"`, `"users."` and `".id"` are all
/// rejected rather than producing a foreign key with a bogus target.
fn parse_fk_reference(reference: &str) -> Option<(String, String)> {
    let (table, column) = reference.split_once('.')?;
    let (table, column) = (table.trim(), column.trim());

    // `split_once` cut at the first dot, so any further dot lives in `column`.
    if table.is_empty() || column.is_empty() || column.contains('.') {
        return None;
    }
    Some((table.to_string(), column.to_string()))
}
144
145// ============================================================================
146// Schema Aggregation
147// ============================================================================
148
149/// Build a DatabaseSchema from a single Model.
150///
151/// # Example
152///
153/// ```ignore
154/// let schema = expected_schema::<Hero>(Dialect::Sqlite);
155/// ```
156pub fn expected_schema<M: Model>(dialect: Dialect) -> DatabaseSchema {
157    let mut schema = DatabaseSchema::new(dialect);
158    let table_info = table_schema_from_model::<M>();
159    schema.tables.insert(table_info.name.clone(), table_info);
160    schema
161}
162
163/// Trait for tuples of Models to aggregate their schemas.
164///
165/// This allows building a complete expected schema from multiple models.
166pub trait ModelTuple {
167    /// Get all table schemas from this tuple of models.
168    fn all_table_schemas() -> Vec<TableInfo>;
169
170    /// Build a complete database schema from all models in this tuple.
171    fn database_schema(dialect: Dialect) -> DatabaseSchema {
172        let mut schema = DatabaseSchema::new(dialect);
173        for table in Self::all_table_schemas() {
174            schema.tables.insert(table.name.clone(), table);
175        }
176        schema
177    }
178}
179
180// Implement for single model
181impl<A: Model> ModelTuple for (A,) {
182    fn all_table_schemas() -> Vec<TableInfo> {
183        vec![table_schema_from_model::<A>()]
184    }
185}
186
187// Implement for 2-tuple
188impl<A: Model, B: Model> ModelTuple for (A, B) {
189    fn all_table_schemas() -> Vec<TableInfo> {
190        vec![
191            table_schema_from_model::<A>(),
192            table_schema_from_model::<B>(),
193        ]
194    }
195}
196
197// Implement for 3-tuple
198impl<A: Model, B: Model, C: Model> ModelTuple for (A, B, C) {
199    fn all_table_schemas() -> Vec<TableInfo> {
200        vec![
201            table_schema_from_model::<A>(),
202            table_schema_from_model::<B>(),
203            table_schema_from_model::<C>(),
204        ]
205    }
206}
207
208// Implement for 4-tuple
209impl<A: Model, B: Model, C: Model, D: Model> ModelTuple for (A, B, C, D) {
210    fn all_table_schemas() -> Vec<TableInfo> {
211        vec![
212            table_schema_from_model::<A>(),
213            table_schema_from_model::<B>(),
214            table_schema_from_model::<C>(),
215            table_schema_from_model::<D>(),
216        ]
217    }
218}
219
220// Implement for 5-tuple
221impl<A: Model, B: Model, C: Model, D: Model, E: Model> ModelTuple for (A, B, C, D, E) {
222    fn all_table_schemas() -> Vec<TableInfo> {
223        vec![
224            table_schema_from_model::<A>(),
225            table_schema_from_model::<B>(),
226            table_schema_from_model::<C>(),
227            table_schema_from_model::<D>(),
228            table_schema_from_model::<E>(),
229        ]
230    }
231}
232
233// Implement for 6-tuple
234impl<A: Model, B: Model, C: Model, D: Model, E: Model, F: Model> ModelTuple for (A, B, C, D, E, F) {
235    fn all_table_schemas() -> Vec<TableInfo> {
236        vec![
237            table_schema_from_model::<A>(),
238            table_schema_from_model::<B>(),
239            table_schema_from_model::<C>(),
240            table_schema_from_model::<D>(),
241            table_schema_from_model::<E>(),
242            table_schema_from_model::<F>(),
243        ]
244    }
245}
246
247// Implement for 7-tuple
248impl<A: Model, B: Model, C: Model, D: Model, E: Model, F: Model, G: Model> ModelTuple
249    for (A, B, C, D, E, F, G)
250{
251    fn all_table_schemas() -> Vec<TableInfo> {
252        vec![
253            table_schema_from_model::<A>(),
254            table_schema_from_model::<B>(),
255            table_schema_from_model::<C>(),
256            table_schema_from_model::<D>(),
257            table_schema_from_model::<E>(),
258            table_schema_from_model::<F>(),
259            table_schema_from_model::<G>(),
260        ]
261    }
262}
263
264// Implement for 8-tuple
265impl<A: Model, B: Model, C: Model, D: Model, E: Model, F: Model, G: Model, H: Model> ModelTuple
266    for (A, B, C, D, E, F, G, H)
267{
268    fn all_table_schemas() -> Vec<TableInfo> {
269        vec![
270            table_schema_from_model::<A>(),
271            table_schema_from_model::<B>(),
272            table_schema_from_model::<C>(),
273            table_schema_from_model::<D>(),
274            table_schema_from_model::<E>(),
275            table_schema_from_model::<F>(),
276            table_schema_from_model::<G>(),
277            table_schema_from_model::<H>(),
278        ]
279    }
280}
281
282// ============================================================================
283// Type Normalization
284// ============================================================================
285
286/// Normalize a SQL type for comparison across dialects.
287///
288/// This handles common type aliases and dialect-specific variations
289/// to enable meaningful comparison between expected and actual schemas.
290pub fn normalize_sql_type(sql_type: &str, dialect: Dialect) -> String {
291    let upper = sql_type.to_uppercase();
292
293    match dialect {
294        Dialect::Sqlite => {
295            // SQLite type affinity normalization
296            if upper.contains("INT") {
297                "INTEGER".to_string()
298            } else if upper.contains("CHAR") || upper.contains("TEXT") || upper.contains("CLOB") {
299                "TEXT".to_string()
300            } else if upper.contains("REAL") || upper.contains("FLOAT") || upper.contains("DOUB") {
301                "REAL".to_string()
302            } else if upper.contains("BLOB") || upper.is_empty() {
303                "BLOB".to_string()
304            } else {
305                // Numeric affinity for anything else
306                upper
307            }
308        }
309        Dialect::Postgres => {
310            // PostgreSQL type normalizations
311            match upper.as_str() {
312                "INT" | "INT4" => "INTEGER".to_string(),
313                "INT8" => "BIGINT".to_string(),
314                "INT2" => "SMALLINT".to_string(),
315                "FLOAT4" => "REAL".to_string(),
316                "FLOAT8" => "DOUBLE PRECISION".to_string(),
317                "BOOL" => "BOOLEAN".to_string(),
318                "SERIAL" => "INTEGER".to_string(), // Serial is INTEGER with sequence
319                "BIGSERIAL" => "BIGINT".to_string(),
320                "SMALLSERIAL" => "SMALLINT".to_string(),
321                _ => upper,
322            }
323        }
324        Dialect::Mysql => {
325            // MySQL type normalizations
326            match upper.as_str() {
327                "INTEGER" => "INT".to_string(),
328                "BOOL" | "BOOLEAN" => "TINYINT".to_string(),
329                _ => upper,
330            }
331        }
332    }
333}
334
335// ============================================================================
336// Unit Tests
337// ============================================================================
338
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{ReferentialAction, Row, SqlType, Value};

    /// Hand-written model covering each kind of field metadata:
    /// - a nullable auto-increment primary key;
    /// - a unique column with a SQL type override;
    /// - an indexed column;
    /// - a foreign key with an `ON DELETE` action.
    struct TestHero;

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt)
                    .nullable(true)
                    .primary_key(true)
                    .auto_increment(true),
                FieldInfo::new("name", "name", SqlType::Text)
                    .sql_type_override("VARCHAR(100)")
                    .unique(true),
                FieldInfo::new("age", "age", SqlType::Integer)
                    .nullable(true)
                    .index("idx_heroes_age"),
                FieldInfo::new("team_id", "team_id", SqlType::BigInt)
                    .nullable(true)
                    .foreign_key("teams.id")
                    .on_delete(ReferentialAction::Cascade),
            ];
            FIELDS
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            Vec::new()
        }

        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self)
        }

        fn primary_key_value(&self) -> Vec<Value> {
            Vec::new()
        }

        fn is_new(&self) -> bool {
            true
        }
    }

    /// The expected table schema every `TestHero` assertion inspects.
    fn hero_schema() -> TableInfo {
        TestHero::table_schema()
    }

    #[test]
    fn test_model_schema_table_name() {
        assert_eq!(hero_schema().name, "heroes");
    }

    #[test]
    fn test_model_schema_columns() {
        let table = hero_schema();
        assert_eq!(table.columns.len(), 4);

        let id = table.column("id").expect("`id` column should exist");
        assert_eq!(id.sql_type, "BIGINT");
        assert!(id.primary_key, "`id` should be the primary key");
        assert!(id.auto_increment, "`id` should auto-increment");

        // The explicit override is reported instead of the `Text` field type.
        let name = table.column("name").expect("`name` column should exist");
        assert_eq!(name.sql_type, "VARCHAR(100)");
        assert!(!name.nullable);
    }

    #[test]
    fn test_model_schema_primary_key() {
        assert_eq!(hero_schema().primary_key, vec!["id"]);
    }

    #[test]
    fn test_model_schema_foreign_keys() {
        let table = hero_schema();
        assert_eq!(table.foreign_keys.len(), 1);

        let fk = &table.foreign_keys[0];
        assert_eq!(
            (
                fk.column.as_str(),
                fk.foreign_table.as_str(),
                fk.foreign_column.as_str()
            ),
            ("team_id", "teams", "id")
        );
        assert_eq!(fk.on_delete.as_deref(), Some("CASCADE"));
    }

    #[test]
    fn test_model_schema_unique_constraints() {
        let table = hero_schema();
        assert_eq!(table.unique_constraints.len(), 1);
        assert_eq!(table.unique_constraints[0].columns, vec!["name"]);
    }

    #[test]
    fn test_model_schema_indexes() {
        let table = hero_schema();
        assert_eq!(table.indexes.len(), 1);

        let index = &table.indexes[0];
        assert_eq!(index.name, "idx_heroes_age");
        assert_eq!(index.columns, vec!["age"]);
        assert!(!index.unique);
    }

    #[test]
    fn test_expected_schema() {
        let db = expected_schema::<TestHero>(Dialect::Sqlite);
        assert_eq!(db.dialect, Dialect::Sqlite);
        assert!(db.table("heroes").is_some());
    }

    #[test]
    fn test_model_tuple_two() {
        /// Single-column model used as the second tuple element.
        struct TestTeam;

        impl Model for TestTeam {
            const TABLE_NAME: &'static str = "teams";
            const PRIMARY_KEY: &'static [&'static str] = &["id"];

            fn fields() -> &'static [FieldInfo] {
                static FIELDS: &[FieldInfo] = &[FieldInfo::new("id", "id", SqlType::BigInt)
                    .nullable(true)
                    .primary_key(true)
                    .auto_increment(true)];
                FIELDS
            }

            fn to_row(&self) -> Vec<(&'static str, Value)> {
                Vec::new()
            }

            fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
                Ok(Self)
            }

            fn primary_key_value(&self) -> Vec<Value> {
                Vec::new()
            }

            fn is_new(&self) -> bool {
                true
            }
        }

        let db = <(TestHero, TestTeam)>::database_schema(Dialect::Postgres);
        assert_eq!(db.tables.len(), 2);
        for name in ["heroes", "teams"].iter() {
            assert!(db.table(name).is_some(), "missing table {:?}", name);
        }
    }

    #[test]
    fn test_normalize_sql_type_sqlite() {
        let cases = [
            ("INTEGER", "INTEGER"),
            ("INT", "INTEGER"),
            ("BIGINT", "INTEGER"),
            ("VARCHAR(100)", "TEXT"),
            ("TEXT", "TEXT"),
            ("REAL", "REAL"),
            ("FLOAT", "REAL"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                normalize_sql_type(input, Dialect::Sqlite),
                expected,
                "normalizing {:?}",
                input
            );
        }
    }

    #[test]
    fn test_normalize_sql_type_postgres() {
        let cases = [
            ("INT", "INTEGER"),
            ("INT4", "INTEGER"),
            ("INT8", "BIGINT"),
            ("FLOAT8", "DOUBLE PRECISION"),
            ("BOOL", "BOOLEAN"),
            ("SERIAL", "INTEGER"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                normalize_sql_type(input, Dialect::Postgres),
                expected,
                "normalizing {:?}",
                input
            );
        }
    }

    #[test]
    fn test_normalize_sql_type_mysql() {
        let cases = [("INTEGER", "INT"), ("BOOLEAN", "TINYINT"), ("BOOL", "TINYINT")];
        for &(input, expected) in &cases {
            assert_eq!(
                normalize_sql_type(input, Dialect::Mysql),
                expected,
                "normalizing {:?}",
                input
            );
        }
    }

    #[test]
    fn test_parse_fk_reference() {
        let ok = |table: &str, column: &str| Some((table.to_string(), column.to_string()));
        assert_eq!(parse_fk_reference("users.id"), ok("users", "id"));
        assert_eq!(parse_fk_reference("teams.team_id"), ok("teams", "team_id"));

        for bad in ["invalid", "too.many.parts"].iter() {
            assert_eq!(parse_fk_reference(bad), None, "accepted {:?}", bad);
        }
    }
}