1use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
2use sql_orm_core::{
3 ColumnMetadata, EntityMetadata, ForeignKeyMetadata, IdentityMetadata, IndexColumnMetadata,
4 IndexMetadata, ReferentialAction, SqlServerType,
5};
6use std::collections::BTreeMap;
7
8#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
10pub struct ModelSnapshot {
11 pub schemas: Vec<SchemaSnapshot>,
12}
13
14impl ModelSnapshot {
15 pub fn new(schemas: Vec<SchemaSnapshot>) -> Self {
16 Self { schemas }
17 }
18
19 pub fn from_entities(entities: &[&'static EntityMetadata]) -> Self {
20 let mut schemas = BTreeMap::<String, Vec<&'static EntityMetadata>>::new();
21
22 for entity in entities {
23 schemas
24 .entry(entity.schema.to_string())
25 .or_default()
26 .push(*entity);
27 }
28
29 let schemas = schemas
30 .into_iter()
31 .map(|(schema_name, mut entities)| {
32 entities.sort_by(|left, right| left.table.cmp(right.table));
33
34 SchemaSnapshot::new(
35 schema_name,
36 entities.into_iter().map(TableSnapshot::from).collect(),
37 )
38 })
39 .collect();
40
41 Self { schemas }
42 }
43
44 pub fn schema(&self, name: &str) -> Option<&SchemaSnapshot> {
45 self.schemas.iter().find(|schema| schema.name == name)
46 }
47
48 pub fn to_json_pretty(&self) -> Result<String, sql_orm_core::OrmError> {
49 serde_json::to_string_pretty(self)
50 .map(|json| format!("{json}\n"))
51 .map_err(|_| sql_orm_core::OrmError::new("failed to serialize model snapshot"))
52 }
53
54 pub fn from_json(json: &str) -> Result<Self, sql_orm_core::OrmError> {
55 serde_json::from_str(json)
56 .map_err(|_| sql_orm_core::OrmError::new("failed to deserialize model snapshot"))
57 }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
62pub struct SchemaSnapshot {
63 pub name: String,
64 pub tables: Vec<TableSnapshot>,
65}
66
67impl SchemaSnapshot {
68 pub fn new(name: impl Into<String>, tables: Vec<TableSnapshot>) -> Self {
69 Self {
70 name: name.into(),
71 tables,
72 }
73 }
74
75 pub fn table(&self, name: &str) -> Option<&TableSnapshot> {
76 self.tables.iter().find(|table| table.name == name)
77 }
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
83pub struct TableSnapshot {
84 pub name: String,
85 pub renamed_from: Option<String>,
86 pub columns: Vec<ColumnSnapshot>,
87 pub primary_key_name: Option<String>,
88 pub primary_key_columns: Vec<String>,
89 pub indexes: Vec<IndexSnapshot>,
90 pub foreign_keys: Vec<ForeignKeySnapshot>,
91}
92
93impl TableSnapshot {
94 pub fn new(
95 name: impl Into<String>,
96 columns: Vec<ColumnSnapshot>,
97 primary_key_name: Option<String>,
98 primary_key_columns: Vec<String>,
99 indexes: Vec<IndexSnapshot>,
100 foreign_keys: Vec<ForeignKeySnapshot>,
101 ) -> Self {
102 Self {
103 name: name.into(),
104 renamed_from: None,
105 columns,
106 primary_key_name,
107 primary_key_columns,
108 indexes,
109 foreign_keys,
110 }
111 }
112
113 pub fn column(&self, name: &str) -> Option<&ColumnSnapshot> {
114 self.columns.iter().find(|column| column.name == name)
115 }
116
117 pub fn with_renamed_from(mut self, renamed_from: impl Into<String>) -> Self {
118 self.renamed_from = Some(renamed_from.into());
119 self
120 }
121
122 pub fn index(&self, name: &str) -> Option<&IndexSnapshot> {
123 self.indexes.iter().find(|index| index.name == name)
124 }
125
126 pub fn foreign_key(&self, name: &str) -> Option<&ForeignKeySnapshot> {
127 self.foreign_keys
128 .iter()
129 .find(|foreign_key| foreign_key.name == name)
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136pub struct ColumnSnapshot {
137 pub name: String,
138 pub renamed_from: Option<String>,
139 #[serde(with = "sql_server_type_json")]
140 pub sql_type: SqlServerType,
141 pub nullable: bool,
142 pub primary_key: bool,
143 #[serde(with = "identity_json")]
144 pub identity: Option<IdentityMetadata>,
145 pub default_sql: Option<String>,
146 pub computed_sql: Option<String>,
147 pub rowversion: bool,
148 pub insertable: bool,
149 pub updatable: bool,
150 pub max_length: Option<u32>,
151 pub precision: Option<u8>,
152 pub scale: Option<u8>,
153}
154
155impl ColumnSnapshot {
156 #[allow(clippy::too_many_arguments)]
157 pub fn new(
158 name: impl Into<String>,
159 sql_type: SqlServerType,
160 nullable: bool,
161 primary_key: bool,
162 identity: Option<IdentityMetadata>,
163 default_sql: Option<String>,
164 computed_sql: Option<String>,
165 rowversion: bool,
166 insertable: bool,
167 updatable: bool,
168 max_length: Option<u32>,
169 precision: Option<u8>,
170 scale: Option<u8>,
171 ) -> Self {
172 Self {
173 name: name.into(),
174 renamed_from: None,
175 sql_type,
176 nullable,
177 primary_key,
178 identity,
179 default_sql,
180 computed_sql,
181 rowversion,
182 insertable,
183 updatable,
184 max_length,
185 precision,
186 scale,
187 }
188 }
189
190 pub fn with_renamed_from(mut self, renamed_from: impl Into<String>) -> Self {
191 self.renamed_from = Some(renamed_from.into());
192 self
193 }
194}
195
196impl From<&ColumnMetadata> for ColumnSnapshot {
197 fn from(column: &ColumnMetadata) -> Self {
198 Self {
199 name: column.column_name.to_string(),
200 renamed_from: column.renamed_from.map(str::to_owned),
201 sql_type: column.sql_type,
202 nullable: column.nullable,
203 primary_key: column.primary_key,
204 identity: column.identity,
205 default_sql: column.default_sql.map(str::to_owned),
206 computed_sql: column.computed_sql.map(str::to_owned),
207 rowversion: column.rowversion,
208 insertable: column.insertable,
209 updatable: column.updatable,
210 max_length: column.max_length,
211 precision: column.precision,
212 scale: column.scale,
213 }
214 }
215}
216
217#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
219pub struct IndexSnapshot {
220 pub name: String,
221 pub columns: Vec<IndexColumnSnapshot>,
222 pub unique: bool,
223}
224
225impl IndexSnapshot {
226 pub fn new(name: impl Into<String>, columns: Vec<IndexColumnSnapshot>, unique: bool) -> Self {
227 Self {
228 name: name.into(),
229 columns,
230 unique,
231 }
232 }
233}
234
235impl From<&IndexMetadata> for IndexSnapshot {
236 fn from(index: &IndexMetadata) -> Self {
237 Self {
238 name: index.name.to_string(),
239 columns: index
240 .columns
241 .iter()
242 .map(IndexColumnSnapshot::from)
243 .collect(),
244 unique: index.unique,
245 }
246 }
247}
248
249#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
251pub struct IndexColumnSnapshot {
252 pub column_name: String,
253 pub descending: bool,
254}
255
256impl IndexColumnSnapshot {
257 pub fn asc(column_name: impl Into<String>) -> Self {
258 Self {
259 column_name: column_name.into(),
260 descending: false,
261 }
262 }
263
264 pub fn desc(column_name: impl Into<String>) -> Self {
265 Self {
266 column_name: column_name.into(),
267 descending: true,
268 }
269 }
270}
271
272impl From<&IndexColumnMetadata> for IndexColumnSnapshot {
273 fn from(column: &IndexColumnMetadata) -> Self {
274 Self {
275 column_name: column.column_name.to_string(),
276 descending: column.descending,
277 }
278 }
279}
280
281#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
283pub struct ForeignKeySnapshot {
284 pub name: String,
285 pub columns: Vec<String>,
286 pub referenced_schema: String,
287 pub referenced_table: String,
288 pub referenced_columns: Vec<String>,
289 #[serde(with = "referential_action_json")]
290 pub on_delete: ReferentialAction,
291 #[serde(with = "referential_action_json")]
292 pub on_update: ReferentialAction,
293}
294
295impl ForeignKeySnapshot {
296 #[allow(clippy::too_many_arguments)]
297 pub fn new(
298 name: impl Into<String>,
299 columns: Vec<String>,
300 referenced_schema: impl Into<String>,
301 referenced_table: impl Into<String>,
302 referenced_columns: Vec<String>,
303 on_delete: ReferentialAction,
304 on_update: ReferentialAction,
305 ) -> Self {
306 Self {
307 name: name.into(),
308 columns,
309 referenced_schema: referenced_schema.into(),
310 referenced_table: referenced_table.into(),
311 referenced_columns,
312 on_delete,
313 on_update,
314 }
315 }
316}
317
318impl From<&ForeignKeyMetadata> for ForeignKeySnapshot {
319 fn from(foreign_key: &ForeignKeyMetadata) -> Self {
320 Self {
321 name: foreign_key.name.to_string(),
322 columns: foreign_key
323 .columns
324 .iter()
325 .map(|column| (*column).to_string())
326 .collect(),
327 referenced_schema: foreign_key.referenced_schema.to_string(),
328 referenced_table: foreign_key.referenced_table.to_string(),
329 referenced_columns: foreign_key
330 .referenced_columns
331 .iter()
332 .map(|column| (*column).to_string())
333 .collect(),
334 on_delete: foreign_key.on_delete,
335 on_update: foreign_key.on_update,
336 }
337 }
338}
339
340impl From<&EntityMetadata> for TableSnapshot {
341 fn from(entity: &EntityMetadata) -> Self {
342 Self {
343 name: entity.table.to_string(),
344 renamed_from: entity.renamed_from.map(str::to_owned),
345 columns: entity.columns.iter().map(ColumnSnapshot::from).collect(),
346 primary_key_name: entity.primary_key.name.map(str::to_owned),
347 primary_key_columns: entity
348 .primary_key
349 .columns
350 .iter()
351 .map(|column| (*column).to_string())
352 .collect(),
353 indexes: entity.indexes.iter().map(IndexSnapshot::from).collect(),
354 foreign_keys: entity
355 .foreign_keys
356 .iter()
357 .map(ForeignKeySnapshot::from)
358 .collect(),
359 }
360 }
361}
362
363mod identity_json {
364 use super::*;
365
366 #[derive(Serialize, Deserialize)]
367 struct IdentitySnapshot {
368 seed: i64,
369 increment: i64,
370 }
371
372 pub fn serialize<S>(
373 identity: &Option<IdentityMetadata>,
374 serializer: S,
375 ) -> Result<S::Ok, S::Error>
376 where
377 S: Serializer,
378 {
379 identity
380 .map(|identity| IdentitySnapshot {
381 seed: identity.seed,
382 increment: identity.increment,
383 })
384 .serialize(serializer)
385 }
386
387 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<IdentityMetadata>, D::Error>
388 where
389 D: Deserializer<'de>,
390 {
391 Option::<IdentitySnapshot>::deserialize(deserializer).map(|identity| {
392 identity.map(|identity| IdentityMetadata::new(identity.seed, identity.increment))
393 })
394 }
395}
396
397mod sql_server_type_json {
398 use super::*;
399
400 pub fn serialize<S>(sql_type: &SqlServerType, serializer: S) -> Result<S::Ok, S::Error>
401 where
402 S: Serializer,
403 {
404 match sql_type {
405 SqlServerType::Custom(value) => serializer.serialize_str(&format!("custom:{value}")),
406 other => serializer.serialize_str(to_str(other)),
407 }
408 }
409
410 pub fn deserialize<'de, D>(deserializer: D) -> Result<SqlServerType, D::Error>
411 where
412 D: Deserializer<'de>,
413 {
414 let value = String::deserialize(deserializer)?;
415 from_str(&value).ok_or_else(|| {
416 de::Error::custom(format!("unsupported SQL Server type in snapshot: {value}"))
417 })
418 }
419
420 fn to_str(sql_type: &SqlServerType) -> &str {
421 match sql_type {
422 SqlServerType::BigInt => "bigint",
423 SqlServerType::Int => "int",
424 SqlServerType::SmallInt => "smallint",
425 SqlServerType::TinyInt => "tinyint",
426 SqlServerType::Bit => "bit",
427 SqlServerType::UniqueIdentifier => "uniqueidentifier",
428 SqlServerType::Date => "date",
429 SqlServerType::DateTime2 => "datetime2",
430 SqlServerType::Decimal => "decimal",
431 SqlServerType::Float => "float",
432 SqlServerType::Money => "money",
433 SqlServerType::NVarChar => "nvarchar",
434 SqlServerType::VarBinary => "varbinary",
435 SqlServerType::RowVersion => "rowversion",
436 SqlServerType::Custom(value) => value,
437 }
438 }
439
440 fn from_str(value: &str) -> Option<SqlServerType> {
441 if let Some(custom) = value.strip_prefix("custom:") {
442 return if custom.is_empty() {
443 None
444 } else {
445 Some(SqlServerType::Custom(leak_static_str(custom)))
446 };
447 }
448
449 match value {
450 "bigint" => Some(SqlServerType::BigInt),
451 "int" => Some(SqlServerType::Int),
452 "smallint" => Some(SqlServerType::SmallInt),
453 "tinyint" => Some(SqlServerType::TinyInt),
454 "bit" => Some(SqlServerType::Bit),
455 "uniqueidentifier" => Some(SqlServerType::UniqueIdentifier),
456 "date" => Some(SqlServerType::Date),
457 "datetime2" => Some(SqlServerType::DateTime2),
458 "decimal" => Some(SqlServerType::Decimal),
459 "float" => Some(SqlServerType::Float),
460 "money" => Some(SqlServerType::Money),
461 "nvarchar" => Some(SqlServerType::NVarChar),
462 "varbinary" => Some(SqlServerType::VarBinary),
463 "rowversion" => Some(SqlServerType::RowVersion),
464 _ => None,
465 }
466 }
467
468 fn leak_static_str(value: &str) -> &'static str {
469 Box::leak(value.to_owned().into_boxed_str())
470 }
471}
472
473mod referential_action_json {
474 use super::*;
475
476 pub fn serialize<S>(action: &ReferentialAction, serializer: S) -> Result<S::Ok, S::Error>
477 where
478 S: Serializer,
479 {
480 serializer.serialize_str(match action {
481 ReferentialAction::NoAction => "no_action",
482 ReferentialAction::Cascade => "cascade",
483 ReferentialAction::SetNull => "set_null",
484 ReferentialAction::SetDefault => "set_default",
485 })
486 }
487
488 pub fn deserialize<'de, D>(deserializer: D) -> Result<ReferentialAction, D::Error>
489 where
490 D: Deserializer<'de>,
491 {
492 let value = String::deserialize(deserializer)?;
493 match value.as_str() {
494 "no_action" => Ok(ReferentialAction::NoAction),
495 "cascade" => Ok(ReferentialAction::Cascade),
496 "set_null" => Ok(ReferentialAction::SetNull),
497 "set_default" => Ok(ReferentialAction::SetDefault),
498 _ => Err(de::Error::custom(format!(
499 "unsupported referential action in snapshot: {value}"
500 ))),
501 }
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::{
        ColumnSnapshot, ForeignKeySnapshot, IndexColumnSnapshot, IndexSnapshot, ModelSnapshot,
        SchemaSnapshot, TableSnapshot,
    };
    use sql_orm_core::{IdentityMetadata, ReferentialAction, SqlServerType};

    /// Builds a snapshot touching every serialized field kind: identity, custom type,
    /// renames, a descending index and a foreign key with non-default actions.
    fn sample_snapshot() -> ModelSnapshot {
        ModelSnapshot::new(vec![SchemaSnapshot::new(
            "sales",
            vec![TableSnapshot {
                name: "orders".to_string(),
                renamed_from: Some("legacy_orders".to_string()),
                columns: vec![
                    ColumnSnapshot::new(
                        "id",
                        SqlServerType::BigInt,
                        false,
                        true,
                        Some(IdentityMetadata::new(1, 1)),
                        None,
                        None,
                        false,
                        false,
                        false,
                        None,
                        None,
                        None,
                    ),
                    // Referenced by `fk_orders_customers` below.
                    ColumnSnapshot::new(
                        "customer_id",
                        SqlServerType::BigInt,
                        false,
                        false,
                        None,
                        None,
                        None,
                        false,
                        true,
                        true,
                        None,
                        None,
                        None,
                    ),
                    ColumnSnapshot::new(
                        "status",
                        SqlServerType::Custom("varchar(24)"),
                        false,
                        false,
                        None,
                        Some("'open'".to_string()),
                        None,
                        false,
                        true,
                        true,
                        Some(24),
                        None,
                        None,
                    )
                    .with_renamed_from("state"),
                ],
                primary_key_name: Some("pk_orders".to_string()),
                primary_key_columns: vec!["id".to_string()],
                indexes: vec![IndexSnapshot::new(
                    "ix_orders_status",
                    vec![IndexColumnSnapshot::desc("status")],
                    false,
                )],
                foreign_keys: vec![ForeignKeySnapshot::new(
                    "fk_orders_customers",
                    vec!["customer_id".to_string()],
                    "sales",
                    "customers",
                    vec!["id".to_string()],
                    ReferentialAction::Cascade,
                    ReferentialAction::NoAction,
                )],
            }],
        )])
    }

    #[test]
    fn serializes_empty_model_snapshot_as_stable_json() {
        let json = ModelSnapshot::default().to_json_pretty().unwrap();

        // serde_json's pretty printer indents with two spaces; `to_json_pretty` appends
        // the trailing newline.
        assert_eq!(json, "{\n  \"schemas\": []\n}\n");
        assert_eq!(
            ModelSnapshot::from_json(&json).unwrap(),
            ModelSnapshot::default()
        );
    }

    #[test]
    fn roundtrips_complete_model_snapshot_json() {
        let snapshot = sample_snapshot();

        let json = snapshot.to_json_pretty().unwrap();
        let parsed = ModelSnapshot::from_json(&json).unwrap();

        assert_eq!(parsed, snapshot);
        assert!(json.contains("\"sql_type\": \"custom:varchar(24)\""));
        assert!(json.contains("\"on_delete\": \"cascade\""));
    }

    #[test]
    fn rejects_unknown_types_and_actions() {
        let json = sample_snapshot().to_json_pretty().unwrap();

        // "bigint" only appears as a sql_type value and "cascade" only as on_delete.
        assert!(ModelSnapshot::from_json(&json.replace("\"bigint\"", "\"bogus\"")).is_err());
        assert!(ModelSnapshot::from_json(&json.replace("\"cascade\"", "\"bogus\"")).is_err());
        assert!(
            ModelSnapshot::from_json(&json.replace("\"custom:varchar(24)\"", "\"custom:\""))
                .is_err()
        );
    }
}