Skip to main content

sql_orm_migrate/
snapshot.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
2use sql_orm_core::{
3    ColumnMetadata, EntityMetadata, ForeignKeyMetadata, IdentityMetadata, IndexColumnMetadata,
4    IndexMetadata, ReferentialAction, SqlServerType,
5};
6use std::collections::BTreeMap;
7
8/// Serializable model snapshot shape used by future migration history artifacts.
9#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
10pub struct ModelSnapshot {
11    pub schemas: Vec<SchemaSnapshot>,
12}
13
14impl ModelSnapshot {
15    pub fn new(schemas: Vec<SchemaSnapshot>) -> Self {
16        Self { schemas }
17    }
18
19    pub fn from_entities(entities: &[&'static EntityMetadata]) -> Self {
20        let mut schemas = BTreeMap::<String, Vec<&'static EntityMetadata>>::new();
21
22        for entity in entities {
23            schemas
24                .entry(entity.schema.to_string())
25                .or_default()
26                .push(*entity);
27        }
28
29        let schemas = schemas
30            .into_iter()
31            .map(|(schema_name, mut entities)| {
32                entities.sort_by(|left, right| left.table.cmp(right.table));
33
34                SchemaSnapshot::new(
35                    schema_name,
36                    entities.into_iter().map(TableSnapshot::from).collect(),
37                )
38            })
39            .collect();
40
41        Self { schemas }
42    }
43
44    pub fn schema(&self, name: &str) -> Option<&SchemaSnapshot> {
45        self.schemas.iter().find(|schema| schema.name == name)
46    }
47
48    pub fn to_json_pretty(&self) -> Result<String, sql_orm_core::OrmError> {
49        serde_json::to_string_pretty(self)
50            .map(|json| format!("{json}\n"))
51            .map_err(|_| sql_orm_core::OrmError::new("failed to serialize model snapshot"))
52    }
53
54    pub fn from_json(json: &str) -> Result<Self, sql_orm_core::OrmError> {
55        serde_json::from_str(json)
56            .map_err(|_| sql_orm_core::OrmError::new("failed to deserialize model snapshot"))
57    }
58}
59
60/// Snapshot of a SQL Server schema and the tables currently modeled inside it.
61#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
62pub struct SchemaSnapshot {
63    pub name: String,
64    pub tables: Vec<TableSnapshot>,
65}
66
67impl SchemaSnapshot {
68    pub fn new(name: impl Into<String>, tables: Vec<TableSnapshot>) -> Self {
69        Self {
70            name: name.into(),
71            tables,
72        }
73    }
74
75    pub fn table(&self, name: &str) -> Option<&TableSnapshot> {
76        self.tables.iter().find(|table| table.name == name)
77    }
78}
79
80/// Snapshot of a SQL Server table with the minimum structural information needed
81/// for the first migration diff passes.
82#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
83pub struct TableSnapshot {
84    pub name: String,
85    pub renamed_from: Option<String>,
86    pub columns: Vec<ColumnSnapshot>,
87    pub primary_key_name: Option<String>,
88    pub primary_key_columns: Vec<String>,
89    pub indexes: Vec<IndexSnapshot>,
90    pub foreign_keys: Vec<ForeignKeySnapshot>,
91}
92
93impl TableSnapshot {
94    pub fn new(
95        name: impl Into<String>,
96        columns: Vec<ColumnSnapshot>,
97        primary_key_name: Option<String>,
98        primary_key_columns: Vec<String>,
99        indexes: Vec<IndexSnapshot>,
100        foreign_keys: Vec<ForeignKeySnapshot>,
101    ) -> Self {
102        Self {
103            name: name.into(),
104            renamed_from: None,
105            columns,
106            primary_key_name,
107            primary_key_columns,
108            indexes,
109            foreign_keys,
110        }
111    }
112
113    pub fn column(&self, name: &str) -> Option<&ColumnSnapshot> {
114        self.columns.iter().find(|column| column.name == name)
115    }
116
117    pub fn with_renamed_from(mut self, renamed_from: impl Into<String>) -> Self {
118        self.renamed_from = Some(renamed_from.into());
119        self
120    }
121
122    pub fn index(&self, name: &str) -> Option<&IndexSnapshot> {
123        self.indexes.iter().find(|index| index.name == name)
124    }
125
126    pub fn foreign_key(&self, name: &str) -> Option<&ForeignKeySnapshot> {
127        self.foreign_keys
128            .iter()
129            .find(|foreign_key| foreign_key.name == name)
130    }
131}
132
133/// Snapshot of a table column, aligned with the code-first metadata already
134/// defined in `sql-orm-core`.
135#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136pub struct ColumnSnapshot {
137    pub name: String,
138    pub renamed_from: Option<String>,
139    #[serde(with = "sql_server_type_json")]
140    pub sql_type: SqlServerType,
141    pub nullable: bool,
142    pub primary_key: bool,
143    #[serde(with = "identity_json")]
144    pub identity: Option<IdentityMetadata>,
145    pub default_sql: Option<String>,
146    pub computed_sql: Option<String>,
147    pub rowversion: bool,
148    pub insertable: bool,
149    pub updatable: bool,
150    pub max_length: Option<u32>,
151    pub precision: Option<u8>,
152    pub scale: Option<u8>,
153}
154
155impl ColumnSnapshot {
156    #[allow(clippy::too_many_arguments)]
157    pub fn new(
158        name: impl Into<String>,
159        sql_type: SqlServerType,
160        nullable: bool,
161        primary_key: bool,
162        identity: Option<IdentityMetadata>,
163        default_sql: Option<String>,
164        computed_sql: Option<String>,
165        rowversion: bool,
166        insertable: bool,
167        updatable: bool,
168        max_length: Option<u32>,
169        precision: Option<u8>,
170        scale: Option<u8>,
171    ) -> Self {
172        Self {
173            name: name.into(),
174            renamed_from: None,
175            sql_type,
176            nullable,
177            primary_key,
178            identity,
179            default_sql,
180            computed_sql,
181            rowversion,
182            insertable,
183            updatable,
184            max_length,
185            precision,
186            scale,
187        }
188    }
189
190    pub fn with_renamed_from(mut self, renamed_from: impl Into<String>) -> Self {
191        self.renamed_from = Some(renamed_from.into());
192        self
193    }
194}
195
196impl From<&ColumnMetadata> for ColumnSnapshot {
197    fn from(column: &ColumnMetadata) -> Self {
198        Self {
199            name: column.column_name.to_string(),
200            renamed_from: column.renamed_from.map(str::to_owned),
201            sql_type: column.sql_type,
202            nullable: column.nullable,
203            primary_key: column.primary_key,
204            identity: column.identity,
205            default_sql: column.default_sql.map(str::to_owned),
206            computed_sql: column.computed_sql.map(str::to_owned),
207            rowversion: column.rowversion,
208            insertable: column.insertable,
209            updatable: column.updatable,
210            max_length: column.max_length,
211            precision: column.precision,
212            scale: column.scale,
213        }
214    }
215}
216
217/// Snapshot of an index, including the participating columns and sort order.
218#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
219pub struct IndexSnapshot {
220    pub name: String,
221    pub columns: Vec<IndexColumnSnapshot>,
222    pub unique: bool,
223}
224
225impl IndexSnapshot {
226    pub fn new(name: impl Into<String>, columns: Vec<IndexColumnSnapshot>, unique: bool) -> Self {
227        Self {
228            name: name.into(),
229            columns,
230            unique,
231        }
232    }
233}
234
235impl From<&IndexMetadata> for IndexSnapshot {
236    fn from(index: &IndexMetadata) -> Self {
237        Self {
238            name: index.name.to_string(),
239            columns: index
240                .columns
241                .iter()
242                .map(IndexColumnSnapshot::from)
243                .collect(),
244            unique: index.unique,
245        }
246    }
247}
248
249/// Snapshot of a column inside an index definition.
250#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
251pub struct IndexColumnSnapshot {
252    pub column_name: String,
253    pub descending: bool,
254}
255
256impl IndexColumnSnapshot {
257    pub fn asc(column_name: impl Into<String>) -> Self {
258        Self {
259            column_name: column_name.into(),
260            descending: false,
261        }
262    }
263
264    pub fn desc(column_name: impl Into<String>) -> Self {
265        Self {
266            column_name: column_name.into(),
267            descending: true,
268        }
269    }
270}
271
272impl From<&IndexColumnMetadata> for IndexColumnSnapshot {
273    fn from(column: &IndexColumnMetadata) -> Self {
274        Self {
275            column_name: column.column_name.to_string(),
276            descending: column.descending,
277        }
278    }
279}
280
281/// Snapshot of a foreign key, including referenced target and referential actions.
282#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
283pub struct ForeignKeySnapshot {
284    pub name: String,
285    pub columns: Vec<String>,
286    pub referenced_schema: String,
287    pub referenced_table: String,
288    pub referenced_columns: Vec<String>,
289    #[serde(with = "referential_action_json")]
290    pub on_delete: ReferentialAction,
291    #[serde(with = "referential_action_json")]
292    pub on_update: ReferentialAction,
293}
294
295impl ForeignKeySnapshot {
296    #[allow(clippy::too_many_arguments)]
297    pub fn new(
298        name: impl Into<String>,
299        columns: Vec<String>,
300        referenced_schema: impl Into<String>,
301        referenced_table: impl Into<String>,
302        referenced_columns: Vec<String>,
303        on_delete: ReferentialAction,
304        on_update: ReferentialAction,
305    ) -> Self {
306        Self {
307            name: name.into(),
308            columns,
309            referenced_schema: referenced_schema.into(),
310            referenced_table: referenced_table.into(),
311            referenced_columns,
312            on_delete,
313            on_update,
314        }
315    }
316}
317
318impl From<&ForeignKeyMetadata> for ForeignKeySnapshot {
319    fn from(foreign_key: &ForeignKeyMetadata) -> Self {
320        Self {
321            name: foreign_key.name.to_string(),
322            columns: foreign_key
323                .columns
324                .iter()
325                .map(|column| (*column).to_string())
326                .collect(),
327            referenced_schema: foreign_key.referenced_schema.to_string(),
328            referenced_table: foreign_key.referenced_table.to_string(),
329            referenced_columns: foreign_key
330                .referenced_columns
331                .iter()
332                .map(|column| (*column).to_string())
333                .collect(),
334            on_delete: foreign_key.on_delete,
335            on_update: foreign_key.on_update,
336        }
337    }
338}
339
340impl From<&EntityMetadata> for TableSnapshot {
341    fn from(entity: &EntityMetadata) -> Self {
342        Self {
343            name: entity.table.to_string(),
344            renamed_from: entity.renamed_from.map(str::to_owned),
345            columns: entity.columns.iter().map(ColumnSnapshot::from).collect(),
346            primary_key_name: entity.primary_key.name.map(str::to_owned),
347            primary_key_columns: entity
348                .primary_key
349                .columns
350                .iter()
351                .map(|column| (*column).to_string())
352                .collect(),
353            indexes: entity.indexes.iter().map(IndexSnapshot::from).collect(),
354            foreign_keys: entity
355                .foreign_keys
356                .iter()
357                .map(ForeignKeySnapshot::from)
358                .collect(),
359        }
360    }
361}
362
363mod identity_json {
364    use super::*;
365
366    #[derive(Serialize, Deserialize)]
367    struct IdentitySnapshot {
368        seed: i64,
369        increment: i64,
370    }
371
372    pub fn serialize<S>(
373        identity: &Option<IdentityMetadata>,
374        serializer: S,
375    ) -> Result<S::Ok, S::Error>
376    where
377        S: Serializer,
378    {
379        identity
380            .map(|identity| IdentitySnapshot {
381                seed: identity.seed,
382                increment: identity.increment,
383            })
384            .serialize(serializer)
385    }
386
387    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<IdentityMetadata>, D::Error>
388    where
389        D: Deserializer<'de>,
390    {
391        Option::<IdentitySnapshot>::deserialize(deserializer).map(|identity| {
392            identity.map(|identity| IdentityMetadata::new(identity.seed, identity.increment))
393        })
394    }
395}
396
397mod sql_server_type_json {
398    use super::*;
399
400    pub fn serialize<S>(sql_type: &SqlServerType, serializer: S) -> Result<S::Ok, S::Error>
401    where
402        S: Serializer,
403    {
404        match sql_type {
405            SqlServerType::Custom(value) => serializer.serialize_str(&format!("custom:{value}")),
406            other => serializer.serialize_str(to_str(other)),
407        }
408    }
409
410    pub fn deserialize<'de, D>(deserializer: D) -> Result<SqlServerType, D::Error>
411    where
412        D: Deserializer<'de>,
413    {
414        let value = String::deserialize(deserializer)?;
415        from_str(&value).ok_or_else(|| {
416            de::Error::custom(format!("unsupported SQL Server type in snapshot: {value}"))
417        })
418    }
419
420    fn to_str(sql_type: &SqlServerType) -> &str {
421        match sql_type {
422            SqlServerType::BigInt => "bigint",
423            SqlServerType::Int => "int",
424            SqlServerType::SmallInt => "smallint",
425            SqlServerType::TinyInt => "tinyint",
426            SqlServerType::Bit => "bit",
427            SqlServerType::UniqueIdentifier => "uniqueidentifier",
428            SqlServerType::Date => "date",
429            SqlServerType::DateTime2 => "datetime2",
430            SqlServerType::Decimal => "decimal",
431            SqlServerType::Float => "float",
432            SqlServerType::Money => "money",
433            SqlServerType::NVarChar => "nvarchar",
434            SqlServerType::VarBinary => "varbinary",
435            SqlServerType::RowVersion => "rowversion",
436            SqlServerType::Custom(value) => value,
437        }
438    }
439
440    fn from_str(value: &str) -> Option<SqlServerType> {
441        if let Some(custom) = value.strip_prefix("custom:") {
442            return if custom.is_empty() {
443                None
444            } else {
445                Some(SqlServerType::Custom(leak_static_str(custom)))
446            };
447        }
448
449        match value {
450            "bigint" => Some(SqlServerType::BigInt),
451            "int" => Some(SqlServerType::Int),
452            "smallint" => Some(SqlServerType::SmallInt),
453            "tinyint" => Some(SqlServerType::TinyInt),
454            "bit" => Some(SqlServerType::Bit),
455            "uniqueidentifier" => Some(SqlServerType::UniqueIdentifier),
456            "date" => Some(SqlServerType::Date),
457            "datetime2" => Some(SqlServerType::DateTime2),
458            "decimal" => Some(SqlServerType::Decimal),
459            "float" => Some(SqlServerType::Float),
460            "money" => Some(SqlServerType::Money),
461            "nvarchar" => Some(SqlServerType::NVarChar),
462            "varbinary" => Some(SqlServerType::VarBinary),
463            "rowversion" => Some(SqlServerType::RowVersion),
464            _ => None,
465        }
466    }
467
468    fn leak_static_str(value: &str) -> &'static str {
469        Box::leak(value.to_owned().into_boxed_str())
470    }
471}
472
473mod referential_action_json {
474    use super::*;
475
476    pub fn serialize<S>(action: &ReferentialAction, serializer: S) -> Result<S::Ok, S::Error>
477    where
478        S: Serializer,
479    {
480        serializer.serialize_str(match action {
481            ReferentialAction::NoAction => "no_action",
482            ReferentialAction::Cascade => "cascade",
483            ReferentialAction::SetNull => "set_null",
484            ReferentialAction::SetDefault => "set_default",
485        })
486    }
487
488    pub fn deserialize<'de, D>(deserializer: D) -> Result<ReferentialAction, D::Error>
489    where
490        D: Deserializer<'de>,
491    {
492        let value = String::deserialize(deserializer)?;
493        match value.as_str() {
494            "no_action" => Ok(ReferentialAction::NoAction),
495            "cascade" => Ok(ReferentialAction::Cascade),
496            "set_null" => Ok(ReferentialAction::SetNull),
497            "set_default" => Ok(ReferentialAction::SetDefault),
498            _ => Err(de::Error::custom(format!(
499                "unsupported referential action in snapshot: {value}"
500            ))),
501        }
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::{
        ColumnSnapshot, ForeignKeySnapshot, IndexColumnSnapshot, IndexSnapshot, ModelSnapshot,
        SchemaSnapshot, TableSnapshot,
    };
    use sql_orm_core::{IdentityMetadata, ReferentialAction, SqlServerType};

    /// A snapshot that exercises every serialized field shape: identity, a
    /// custom SQL type, a default, renames, a descending index and a foreign key.
    fn sample_snapshot() -> ModelSnapshot {
        ModelSnapshot::new(vec![SchemaSnapshot::new(
            "sales",
            vec![TableSnapshot {
                name: "orders".to_string(),
                renamed_from: Some("legacy_orders".to_string()),
                columns: vec![
                    ColumnSnapshot::new(
                        "id",
                        SqlServerType::BigInt,
                        false,
                        true,
                        Some(IdentityMetadata::new(1, 1)),
                        None,
                        None,
                        false,
                        false,
                        false,
                        None,
                        None,
                        None,
                    ),
                    ColumnSnapshot::new(
                        "status",
                        SqlServerType::Custom("varchar(24)"),
                        false,
                        false,
                        None,
                        Some("'open'".to_string()),
                        None,
                        false,
                        true,
                        true,
                        Some(24),
                        None,
                        None,
                    )
                    .with_renamed_from("state"),
                ],
                primary_key_name: Some("pk_orders".to_string()),
                primary_key_columns: vec!["id".to_string()],
                indexes: vec![IndexSnapshot::new(
                    "ix_orders_status",
                    vec![IndexColumnSnapshot::desc("status")],
                    false,
                )],
                foreign_keys: vec![ForeignKeySnapshot::new(
                    "fk_orders_customers",
                    vec!["customer_id".to_string()],
                    "sales",
                    "customers",
                    vec!["id".to_string()],
                    ReferentialAction::Cascade,
                    ReferentialAction::NoAction,
                )],
            }],
        )])
    }

    /// Serializes the sample snapshot, swaps `from` for `to`, and asserts the
    /// tampered JSON is rejected by `from_json`.
    fn assert_tampered_json_is_rejected(from: &str, to: &str) {
        let json = sample_snapshot().to_json_pretty().unwrap();
        let tampered = json.replace(from, to);
        assert_ne!(tampered, json, "fixture must contain {from}");
        assert!(ModelSnapshot::from_json(&tampered).is_err());
    }

    #[test]
    fn serializes_empty_model_snapshot_as_stable_json() {
        let json = ModelSnapshot::default().to_json_pretty().unwrap();

        assert_eq!(json, "{\n  \"schemas\": []\n}\n");
        assert_eq!(
            ModelSnapshot::from_json(&json).unwrap(),
            ModelSnapshot::default()
        );
    }

    #[test]
    fn roundtrips_complete_model_snapshot_json() {
        let snapshot = sample_snapshot();

        let json = snapshot.to_json_pretty().unwrap();
        let parsed = ModelSnapshot::from_json(&json).unwrap();

        assert_eq!(parsed, snapshot);
        assert!(json.ends_with('\n'));
        assert!(json.contains("\"sql_type\": \"custom:varchar(24)\""));
        assert!(json.contains("\"on_delete\": \"cascade\""));
    }

    #[test]
    fn looks_up_schema_table_column_index_and_foreign_key_by_name() {
        let snapshot = sample_snapshot();
        let table = snapshot
            .schema("sales")
            .and_then(|schema| schema.table("orders"))
            .expect("sample snapshot contains sales.orders");

        assert_eq!(
            table
                .column("status")
                .and_then(|column| column.renamed_from.as_deref()),
            Some("state")
        );
        assert!(table.index("ix_orders_status").is_some());
        assert!(table.foreign_key("fk_orders_customers").is_some());
        assert!(table.column("missing").is_none());
        assert!(snapshot.schema("missing").is_none());
    }

    #[test]
    fn rejects_unknown_sql_type() {
        assert_tampered_json_is_rejected("\"bigint\"", "\"bogus\"");
    }

    #[test]
    fn rejects_empty_custom_sql_type() {
        assert_tampered_json_is_rejected("\"custom:varchar(24)\"", "\"custom:\"");
    }

    #[test]
    fn rejects_unknown_referential_action() {
        assert_tampered_json_is_rejected("\"cascade\"", "\"explode\"");
    }
}
590}