Skip to main content

sql_orm_migrate/
snapshot.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
2use sql_orm_core::{
3    ColumnMetadata, EntityMetadata, ForeignKeyMetadata, IdentityMetadata, IndexColumnMetadata,
4    IndexMetadata, ReferentialAction, SqlServerType,
5};
6use std::collections::BTreeMap;
7
8/// Serializable model snapshot shape used by future migration history artifacts.
9#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
10pub struct ModelSnapshot {
11    pub schemas: Vec<SchemaSnapshot>,
12}
13
14impl ModelSnapshot {
15    pub fn new(schemas: Vec<SchemaSnapshot>) -> Self {
16        Self { schemas }
17    }
18
19    pub fn from_entities(entities: &[&'static EntityMetadata]) -> Self {
20        let mut schemas = BTreeMap::<String, Vec<&'static EntityMetadata>>::new();
21
22        for entity in entities {
23            schemas
24                .entry(entity.schema.to_string())
25                .or_default()
26                .push(*entity);
27        }
28
29        let schemas = schemas
30            .into_iter()
31            .map(|(schema_name, mut entities)| {
32                entities.sort_by(|left, right| left.table.cmp(right.table));
33
34                SchemaSnapshot::new(
35                    schema_name,
36                    entities.into_iter().map(TableSnapshot::from).collect(),
37                )
38            })
39            .collect();
40
41        Self { schemas }
42    }
43
44    pub fn schema(&self, name: &str) -> Option<&SchemaSnapshot> {
45        self.schemas.iter().find(|schema| schema.name == name)
46    }
47
48    pub fn to_json_pretty(&self) -> Result<String, sql_orm_core::OrmError> {
49        serde_json::to_string_pretty(self)
50            .map(|json| format!("{json}\n"))
51            .map_err(|_| sql_orm_core::OrmError::migration("failed to serialize model snapshot"))
52    }
53
54    pub fn from_json(json: &str) -> Result<Self, sql_orm_core::OrmError> {
55        serde_json::from_str(json)
56            .map_err(|_| sql_orm_core::OrmError::migration("failed to deserialize model snapshot"))
57    }
58}
59
60/// Snapshot of a SQL Server schema and the tables currently modeled inside it.
61#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
62pub struct SchemaSnapshot {
63    pub name: String,
64    pub tables: Vec<TableSnapshot>,
65}
66
67impl SchemaSnapshot {
68    pub fn new(name: impl Into<String>, tables: Vec<TableSnapshot>) -> Self {
69        Self {
70            name: name.into(),
71            tables,
72        }
73    }
74
75    pub fn table(&self, name: &str) -> Option<&TableSnapshot> {
76        self.tables.iter().find(|table| table.name == name)
77    }
78}
79
80/// Snapshot of a SQL Server table with the minimum structural information needed
81/// for the first migration diff passes.
82#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
83pub struct TableSnapshot {
84    pub name: String,
85    pub renamed_from: Option<String>,
86    pub columns: Vec<ColumnSnapshot>,
87    pub primary_key_name: Option<String>,
88    pub primary_key_columns: Vec<String>,
89    pub indexes: Vec<IndexSnapshot>,
90    pub foreign_keys: Vec<ForeignKeySnapshot>,
91}
92
93impl TableSnapshot {
94    pub fn new(
95        name: impl Into<String>,
96        columns: Vec<ColumnSnapshot>,
97        primary_key_name: Option<String>,
98        primary_key_columns: Vec<String>,
99        indexes: Vec<IndexSnapshot>,
100        foreign_keys: Vec<ForeignKeySnapshot>,
101    ) -> Self {
102        Self {
103            name: name.into(),
104            renamed_from: None,
105            columns,
106            primary_key_name,
107            primary_key_columns,
108            indexes,
109            foreign_keys,
110        }
111    }
112
113    pub fn column(&self, name: &str) -> Option<&ColumnSnapshot> {
114        self.columns.iter().find(|column| column.name == name)
115    }
116
117    pub fn with_renamed_from(mut self, renamed_from: impl Into<String>) -> Self {
118        self.renamed_from = Some(renamed_from.into());
119        self
120    }
121
122    pub fn index(&self, name: &str) -> Option<&IndexSnapshot> {
123        self.indexes.iter().find(|index| index.name == name)
124    }
125
126    pub fn foreign_key(&self, name: &str) -> Option<&ForeignKeySnapshot> {
127        self.foreign_keys
128            .iter()
129            .find(|foreign_key| foreign_key.name == name)
130    }
131}
132
133/// Snapshot of a table column, aligned with the code-first metadata already
134/// defined in `sql-orm-core`.
135#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136pub struct ColumnSnapshot {
137    pub name: String,
138    pub renamed_from: Option<String>,
139    #[serde(with = "sql_server_type_json")]
140    pub sql_type: SqlServerType,
141    pub nullable: bool,
142    pub primary_key: bool,
143    #[serde(with = "identity_json")]
144    pub identity: Option<IdentityMetadata>,
145    pub default_sql: Option<String>,
146    pub computed_sql: Option<String>,
147    pub rowversion: bool,
148    pub insertable: bool,
149    pub updatable: bool,
150    pub max_length: Option<u32>,
151    pub precision: Option<u8>,
152    pub scale: Option<u8>,
153}
154
155impl ColumnSnapshot {
156    #[allow(clippy::too_many_arguments)]
157    pub fn new(
158        name: impl Into<String>,
159        sql_type: SqlServerType,
160        nullable: bool,
161        primary_key: bool,
162        identity: Option<IdentityMetadata>,
163        default_sql: Option<String>,
164        computed_sql: Option<String>,
165        rowversion: bool,
166        insertable: bool,
167        updatable: bool,
168        max_length: Option<u32>,
169        precision: Option<u8>,
170        scale: Option<u8>,
171    ) -> Self {
172        Self {
173            name: name.into(),
174            renamed_from: None,
175            sql_type,
176            nullable,
177            primary_key,
178            identity,
179            default_sql,
180            computed_sql,
181            rowversion,
182            insertable,
183            updatable,
184            max_length,
185            precision,
186            scale,
187        }
188    }
189
190    pub fn with_renamed_from(mut self, renamed_from: impl Into<String>) -> Self {
191        self.renamed_from = Some(renamed_from.into());
192        self
193    }
194}
195
196impl From<&ColumnMetadata> for ColumnSnapshot {
197    fn from(column: &ColumnMetadata) -> Self {
198        Self {
199            name: column.column_name.to_string(),
200            renamed_from: column.renamed_from.map(str::to_owned),
201            sql_type: column.sql_type,
202            nullable: column.nullable,
203            primary_key: column.primary_key,
204            identity: column.identity,
205            default_sql: column.default_sql.map(str::to_owned),
206            computed_sql: column.computed_sql.map(str::to_owned),
207            rowversion: column.rowversion,
208            insertable: column.insertable,
209            updatable: column.updatable,
210            max_length: column.max_length,
211            precision: column.precision,
212            scale: column.scale,
213        }
214    }
215}
216
217/// Snapshot of an index, including the participating columns and sort order.
218#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
219pub struct IndexSnapshot {
220    pub name: String,
221    pub columns: Vec<IndexColumnSnapshot>,
222    pub unique: bool,
223}
224
225impl IndexSnapshot {
226    pub fn new(name: impl Into<String>, columns: Vec<IndexColumnSnapshot>, unique: bool) -> Self {
227        Self {
228            name: name.into(),
229            columns,
230            unique,
231        }
232    }
233}
234
235impl From<&IndexMetadata> for IndexSnapshot {
236    fn from(index: &IndexMetadata) -> Self {
237        Self {
238            name: index.name.to_string(),
239            columns: index
240                .columns
241                .iter()
242                .map(IndexColumnSnapshot::from)
243                .collect(),
244            unique: index.unique,
245        }
246    }
247}
248
249/// Snapshot of a column inside an index definition.
250#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
251pub struct IndexColumnSnapshot {
252    pub column_name: String,
253    pub descending: bool,
254}
255
256impl IndexColumnSnapshot {
257    pub fn asc(column_name: impl Into<String>) -> Self {
258        Self {
259            column_name: column_name.into(),
260            descending: false,
261        }
262    }
263
264    pub fn desc(column_name: impl Into<String>) -> Self {
265        Self {
266            column_name: column_name.into(),
267            descending: true,
268        }
269    }
270}
271
272impl From<&IndexColumnMetadata> for IndexColumnSnapshot {
273    fn from(column: &IndexColumnMetadata) -> Self {
274        Self {
275            column_name: column.column_name.to_string(),
276            descending: column.descending,
277        }
278    }
279}
280
281/// Snapshot of a foreign key, including referenced target and referential actions.
282#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
283pub struct ForeignKeySnapshot {
284    pub name: String,
285    pub columns: Vec<String>,
286    pub referenced_schema: String,
287    pub referenced_table: String,
288    pub referenced_columns: Vec<String>,
289    #[serde(with = "referential_action_json")]
290    pub on_delete: ReferentialAction,
291    #[serde(with = "referential_action_json")]
292    pub on_update: ReferentialAction,
293}
294
295impl ForeignKeySnapshot {
296    #[allow(clippy::too_many_arguments)]
297    pub fn new(
298        name: impl Into<String>,
299        columns: Vec<String>,
300        referenced_schema: impl Into<String>,
301        referenced_table: impl Into<String>,
302        referenced_columns: Vec<String>,
303        on_delete: ReferentialAction,
304        on_update: ReferentialAction,
305    ) -> Self {
306        Self {
307            name: name.into(),
308            columns,
309            referenced_schema: referenced_schema.into(),
310            referenced_table: referenced_table.into(),
311            referenced_columns,
312            on_delete,
313            on_update,
314        }
315    }
316}
317
318impl From<&ForeignKeyMetadata> for ForeignKeySnapshot {
319    fn from(foreign_key: &ForeignKeyMetadata) -> Self {
320        Self {
321            name: foreign_key.name.to_string(),
322            columns: foreign_key
323                .columns
324                .iter()
325                .map(|column| (*column).to_string())
326                .collect(),
327            referenced_schema: foreign_key.referenced_schema.to_string(),
328            referenced_table: foreign_key.referenced_table.to_string(),
329            referenced_columns: foreign_key
330                .referenced_columns
331                .iter()
332                .map(|column| (*column).to_string())
333                .collect(),
334            on_delete: foreign_key.on_delete,
335            on_update: foreign_key.on_update,
336        }
337    }
338}
339
340impl From<&EntityMetadata> for TableSnapshot {
341    fn from(entity: &EntityMetadata) -> Self {
342        Self {
343            name: entity.table.to_string(),
344            renamed_from: entity.renamed_from.map(str::to_owned),
345            columns: entity.columns.iter().map(ColumnSnapshot::from).collect(),
346            primary_key_name: entity.primary_key.name.map(str::to_owned),
347            primary_key_columns: entity
348                .primary_key
349                .columns
350                .iter()
351                .map(|column| (*column).to_string())
352                .collect(),
353            indexes: entity.indexes.iter().map(IndexSnapshot::from).collect(),
354            foreign_keys: entity
355                .foreign_keys
356                .iter()
357                .map(ForeignKeySnapshot::from)
358                .collect(),
359        }
360    }
361}
362
363mod identity_json {
364    use super::*;
365
366    #[derive(Serialize, Deserialize)]
367    struct IdentitySnapshot {
368        seed: i64,
369        increment: i64,
370    }
371
372    pub fn serialize<S>(
373        identity: &Option<IdentityMetadata>,
374        serializer: S,
375    ) -> Result<S::Ok, S::Error>
376    where
377        S: Serializer,
378    {
379        identity
380            .map(|identity| IdentitySnapshot {
381                seed: identity.seed,
382                increment: identity.increment,
383            })
384            .serialize(serializer)
385    }
386
387    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<IdentityMetadata>, D::Error>
388    where
389        D: Deserializer<'de>,
390    {
391        Option::<IdentitySnapshot>::deserialize(deserializer).map(|identity| {
392            identity.map(|identity| IdentityMetadata::new(identity.seed, identity.increment))
393        })
394    }
395}
396
397mod sql_server_type_json {
398    use super::*;
399
400    pub fn serialize<S>(sql_type: &SqlServerType, serializer: S) -> Result<S::Ok, S::Error>
401    where
402        S: Serializer,
403    {
404        match sql_type {
405            SqlServerType::Custom(value) => serializer.serialize_str(&format!("custom:{value}")),
406            other => serializer.serialize_str(to_str(other)),
407        }
408    }
409
410    pub fn deserialize<'de, D>(deserializer: D) -> Result<SqlServerType, D::Error>
411    where
412        D: Deserializer<'de>,
413    {
414        let value = String::deserialize(deserializer)?;
415        from_str(&value).ok_or_else(|| {
416            de::Error::custom(format!("unsupported SQL Server type in snapshot: {value}"))
417        })
418    }
419
420    fn to_str(sql_type: &SqlServerType) -> &str {
421        match sql_type {
422            SqlServerType::BigInt => "bigint",
423            SqlServerType::Int => "int",
424            SqlServerType::SmallInt => "smallint",
425            SqlServerType::TinyInt => "tinyint",
426            SqlServerType::Bit => "bit",
427            SqlServerType::UniqueIdentifier => "uniqueidentifier",
428            SqlServerType::Date => "date",
429            SqlServerType::Time => "time",
430            SqlServerType::DateTime2 => "datetime2",
431            SqlServerType::DateTimeOffset => "datetimeoffset",
432            SqlServerType::Decimal => "decimal",
433            SqlServerType::Float => "float",
434            SqlServerType::Money => "money",
435            SqlServerType::NVarChar => "nvarchar",
436            SqlServerType::VarBinary => "varbinary",
437            SqlServerType::RowVersion => "rowversion",
438            SqlServerType::Custom(value) => value,
439        }
440    }
441
442    fn from_str(value: &str) -> Option<SqlServerType> {
443        if let Some(custom) = value.strip_prefix("custom:") {
444            return if custom.is_empty() {
445                None
446            } else {
447                Some(SqlServerType::Custom(leak_static_str(custom)))
448            };
449        }
450
451        match value {
452            "bigint" => Some(SqlServerType::BigInt),
453            "int" => Some(SqlServerType::Int),
454            "smallint" => Some(SqlServerType::SmallInt),
455            "tinyint" => Some(SqlServerType::TinyInt),
456            "bit" => Some(SqlServerType::Bit),
457            "uniqueidentifier" => Some(SqlServerType::UniqueIdentifier),
458            "date" => Some(SqlServerType::Date),
459            "time" => Some(SqlServerType::Time),
460            "datetime2" => Some(SqlServerType::DateTime2),
461            "datetimeoffset" => Some(SqlServerType::DateTimeOffset),
462            "decimal" => Some(SqlServerType::Decimal),
463            "float" => Some(SqlServerType::Float),
464            "money" => Some(SqlServerType::Money),
465            "nvarchar" => Some(SqlServerType::NVarChar),
466            "varbinary" => Some(SqlServerType::VarBinary),
467            "rowversion" => Some(SqlServerType::RowVersion),
468            _ => None,
469        }
470    }
471
472    fn leak_static_str(value: &str) -> &'static str {
473        Box::leak(value.to_owned().into_boxed_str())
474    }
475}
476
477mod referential_action_json {
478    use super::*;
479
480    pub fn serialize<S>(action: &ReferentialAction, serializer: S) -> Result<S::Ok, S::Error>
481    where
482        S: Serializer,
483    {
484        serializer.serialize_str(match action {
485            ReferentialAction::NoAction => "no_action",
486            ReferentialAction::Cascade => "cascade",
487            ReferentialAction::SetNull => "set_null",
488            ReferentialAction::SetDefault => "set_default",
489        })
490    }
491
492    pub fn deserialize<'de, D>(deserializer: D) -> Result<ReferentialAction, D::Error>
493    where
494        D: Deserializer<'de>,
495    {
496        let value = String::deserialize(deserializer)?;
497        match value.as_str() {
498            "no_action" => Ok(ReferentialAction::NoAction),
499            "cascade" => Ok(ReferentialAction::Cascade),
500            "set_null" => Ok(ReferentialAction::SetNull),
501            "set_default" => Ok(ReferentialAction::SetDefault),
502            _ => Err(de::Error::custom(format!(
503                "unsupported referential action in snapshot: {value}"
504            ))),
505        }
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::{
        ColumnSnapshot, ForeignKeySnapshot, IndexColumnSnapshot, IndexSnapshot, ModelSnapshot,
        SchemaSnapshot, TableSnapshot,
    };
    use sql_orm_core::{IdentityMetadata, OrmErrorKind, ReferentialAction, SqlServerType};

    /// Fixture covering every field kind: an identity column, a custom SQL
    /// type, table and column renames, a descending index, and a foreign key.
    fn complete_snapshot() -> ModelSnapshot {
        ModelSnapshot::new(vec![SchemaSnapshot::new(
            "sales",
            vec![TableSnapshot {
                name: "orders".to_string(),
                renamed_from: Some("legacy_orders".to_string()),
                columns: vec![
                    ColumnSnapshot::new(
                        "id",
                        SqlServerType::BigInt,
                        false,
                        true,
                        Some(IdentityMetadata::new(1, 1)),
                        None,
                        None,
                        false,
                        false,
                        false,
                        None,
                        None,
                        None,
                    ),
                    ColumnSnapshot::new(
                        "status",
                        SqlServerType::Custom("varchar(24)"),
                        false,
                        false,
                        None,
                        Some("'open'".to_string()),
                        None,
                        false,
                        true,
                        true,
                        Some(24),
                        None,
                        None,
                    )
                    .with_renamed_from("state"),
                ],
                primary_key_name: Some("pk_orders".to_string()),
                primary_key_columns: vec!["id".to_string()],
                indexes: vec![IndexSnapshot::new(
                    "ix_orders_status",
                    vec![IndexColumnSnapshot::desc("status")],
                    false,
                )],
                foreign_keys: vec![ForeignKeySnapshot::new(
                    "fk_orders_customers",
                    vec!["customer_id".to_string()],
                    "sales",
                    "customers",
                    vec!["id".to_string()],
                    ReferentialAction::Cascade,
                    ReferentialAction::NoAction,
                )],
            }],
        )])
    }

    /// Serializes the fixture and replaces the first occurrence of `from` with
    /// `to`. Fails if the fixture JSON no longer contains `from`, so the
    /// tampering tests cannot silently test nothing.
    fn tampered_complete_json(from: &str, to: &str) -> String {
        let json = complete_snapshot().to_json_pretty().unwrap();
        assert!(json.contains(from), "fixture JSON no longer contains {from}");
        json.replacen(from, to, 1)
    }

    /// Asserts that `json` is rejected with the generic migration error.
    fn assert_rejected_as_migration_error(json: &str) {
        let error = ModelSnapshot::from_json(json).unwrap_err();

        assert_eq!(error.kind(), OrmErrorKind::Migration);
        assert_eq!(error.message(), "failed to deserialize model snapshot");
    }

    #[test]
    fn serializes_empty_model_snapshot_as_stable_json() {
        let json = ModelSnapshot::default().to_json_pretty().unwrap();

        assert_eq!(json, "{\n  \"schemas\": []\n}\n");
        assert_eq!(
            ModelSnapshot::from_json(&json).unwrap(),
            ModelSnapshot::default()
        );
    }

    #[test]
    fn classifies_invalid_model_snapshot_json_as_migration_error() {
        assert_rejected_as_migration_error("{");
    }

    #[test]
    fn roundtrips_complete_model_snapshot_json() {
        let snapshot = complete_snapshot();

        let json = snapshot.to_json_pretty().unwrap();
        let parsed = ModelSnapshot::from_json(&json).unwrap();

        assert_eq!(parsed, snapshot);
        assert!(json.contains("\"sql_type\": \"custom:varchar(24)\""));
        assert!(json.contains("\"on_delete\": \"cascade\""));
    }

    #[test]
    fn serializes_identity_as_object_or_null() {
        let json = complete_snapshot().to_json_pretty().unwrap();

        assert!(json.contains("\"seed\": 1"));
        assert!(json.contains("\"increment\": 1"));
        assert!(json.contains("\"identity\": null"));
    }

    #[test]
    fn rejects_unknown_builtin_sql_server_type() {
        let json = tampered_complete_json("\"sql_type\": \"bigint\"", "\"sql_type\": \"bignum\"");

        assert_rejected_as_migration_error(&json);
    }

    #[test]
    fn rejects_custom_sql_server_type_with_empty_definition() {
        let json = tampered_complete_json("\"custom:varchar(24)\"", "\"custom:\"");

        assert_rejected_as_migration_error(&json);
    }

    #[test]
    fn rejects_unknown_referential_action() {
        let json = tampered_complete_json(
            "\"on_delete\": \"cascade\"",
            "\"on_delete\": \"restrict\"",
        );

        assert_rejected_as_migration_error(&json);
    }
}