1use super::{Schema, TopologicalSort};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4 ForeignKeyCreateStatement, IndexCreateStatement, TableAlterStatement, TableCreateStatement,
5 TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
/// Collects the DDL statements (table, enum types, indexes) derived from registered
/// entities and executes them against a database, ordered so that tables referenced
/// by foreign keys are processed before the tables that reference them.
pub struct SchemaBuilder {
    /// Derives statements from entities; its `backend` is also used when rendering `Debug`.
    helper: Schema,
    /// Registered entities. `register` skips an entity whose table name is already present.
    entities: Vec<EntitySchemaInfo>,
}
13
/// The DDL statements derived from a single entity.
pub struct EntitySchemaInfo {
    /// `CREATE TABLE` statement, including the entity's foreign keys.
    table: TableCreateStatement,
    /// Enum type statements produced by `Schema::create_enum_from_entity`.
    enums: Vec<TypeCreateStatement>,
    /// Index statements produced by `Schema::create_index_from_entity`.
    indexes: Vec<IndexCreateStatement>,
}
20
21impl std::fmt::Debug for SchemaBuilder {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "SchemaBuilder {{")?;
24 write!(f, " entities: [")?;
25 for (i, entity) in self.entities.iter().enumerate() {
26 if i > 0 {
27 write!(f, ", ")?;
28 }
29 entity.debug_print(f, &self.helper.backend)?;
30 }
31 write!(f, " ]")?;
32 write!(f, " }}")
33 }
34}
35
36impl std::fmt::Debug for EntitySchemaInfo {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 self.debug_print(f, &DbBackend::Sqlite)
39 }
40}
41
42impl SchemaBuilder {
43 pub fn new(schema: Schema) -> Self {
45 Self {
46 helper: schema,
47 entities: Default::default(),
48 }
49 }
50
51 pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
53 let entity = EntitySchemaInfo::new(entity, &self.helper);
54 if !self
55 .entities
56 .iter()
57 .any(|e| e.table.get_table_name() == entity.table.get_table_name())
58 {
59 self.entities.push(entity);
60 }
61 self
62 }
63
64 #[cfg(feature = "entity-registry")]
65 pub(crate) fn helper(&self) -> &Schema {
66 &self.helper
67 }
68
69 #[cfg(feature = "entity-registry")]
70 pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
71 self.entities.push(entity);
72 }
73
74 #[cfg(feature = "schema-sync")]
77 #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
78 pub async fn sync<C>(self, db: &C) -> Result<(), DbErr>
79 where
80 C: ConnectionTrait + sea_schema::Connection,
81 {
82 let _existing =
83 match db.get_database_backend() {
84 #[cfg(feature = "sqlx-mysql")]
85 DbBackend::MySql => {
86 use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
87
88 let current_schema: String = db
89 .query_one(
90 sea_query::SelectStatement::new()
91 .expr(sea_schema::mysql::MySql::get_current_schema()),
92 )
93 .await?
94 .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
95 .try_get_by_index(0)?;
96 let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema);
97
98 let schema = schema_discovery.discover_with(db).await.map_err(|err| {
99 DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}")))
100 })?;
101
102 DiscoveredSchema {
103 tables: schema.tables.iter().map(|table| table.write()).collect(),
104 enums: vec![],
105 }
106 }
107 #[cfg(feature = "sqlx-postgres")]
108 DbBackend::Postgres => {
109 use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
110
111 let current_schema: String = db
112 .query_one(
113 sea_query::SelectStatement::new()
114 .expr(sea_schema::postgres::Postgres::get_current_schema()),
115 )
116 .await?
117 .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
118 .try_get_by_index(0)?;
119 let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema);
120
121 let schema = schema_discovery.discover_with(db).await.map_err(|err| {
122 DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}")))
123 })?;
124
125 DiscoveredSchema {
126 tables: schema.tables.iter().map(|table| table.write()).collect(),
127 enums: schema.enums.iter().map(|def| def.write()).collect(),
128 }
129 }
130 #[cfg(feature = "sqlx-sqlite")]
131 DbBackend::Sqlite => {
132 use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
133 let schema = SchemaDiscovery::discover_with(db)
134 .await
135 .map_err(|err| {
136 DbErr::Query(match err {
137 SqliteDiscoveryError::SqlxError(err) => {
138 crate::RuntimeErr::SqlxError(err.into())
139 }
140 _ => crate::RuntimeErr::Internal(format!("{err:?}")),
141 })
142 })?
143 .merge_indexes_into_table();
144 DiscoveredSchema {
145 tables: schema.tables.iter().map(|table| table.write()).collect(),
146 enums: vec![],
147 }
148 }
149 #[cfg(feature = "rusqlite")]
150 DbBackend::Sqlite => {
151 use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
152 let schema = SchemaDiscovery::discover_with(db)
153 .map_err(|err| {
154 DbErr::Query(match err {
155 SqliteDiscoveryError::RusqliteError(err) => {
156 crate::RuntimeErr::Rusqlite(err.into())
157 }
158 _ => crate::RuntimeErr::Internal(format!("{err:?}")),
159 })
160 })?
161 .merge_indexes_into_table();
162 DiscoveredSchema {
163 tables: schema.tables.iter().map(|table| table.write()).collect(),
164 enums: vec![],
165 }
166 }
167 #[allow(unreachable_patterns)]
168 other => {
169 return Err(DbErr::BackendNotSupported {
170 db: other.as_str(),
171 ctx: "SchemaBuilder::sync",
172 });
173 }
174 };
175
176 #[allow(unreachable_code)]
177 let mut created_enums: Vec<Statement> = Default::default();
178
179 #[allow(unreachable_code)]
180 for table_name in self.sorted_tables() {
181 if let Some(entity) = self
182 .entities
183 .iter()
184 .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
185 {
186 entity.sync(db, &_existing, &mut created_enums).await?;
187 }
188 }
189
190 Ok(())
191 }
192
193 pub async fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
196 let mut created_enums: Vec<Statement> = Default::default();
197
198 for table_name in self.sorted_tables() {
199 if let Some(entity) = self
200 .entities
201 .iter()
202 .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
203 {
204 entity.apply(db, &mut created_enums).await?;
205 }
206 }
207
208 Ok(())
209 }
210
211 fn sorted_tables(&self) -> Vec<TableName> {
212 let mut sorter = TopologicalSort::<TableName>::new();
213
214 for entity in self.entities.iter() {
215 let table_name = get_table_name(entity.table.get_table_name());
216 sorter.insert(table_name);
217 }
218 for entity in self.entities.iter() {
219 let self_table = get_table_name(entity.table.get_table_name());
220 for fk in entity.table.get_foreign_key_create_stmts().iter() {
221 let fk = fk.get_foreign_key();
222 let ref_table = get_table_name(fk.get_ref_table());
223 if self_table != ref_table {
224 sorter.add_dependency(ref_table, self_table.clone());
226 }
227 }
228 }
229 let mut sorted = Vec::new();
230 while let Some(i) = sorter.pop() {
231 sorted.push(i);
232 }
233 if sorted.len() != self.entities.len() {
234 for entity in self.entities.iter() {
236 let table_name = get_table_name(entity.table.get_table_name());
237 if !sorted.contains(&table_name) {
238 sorted.push(table_name);
239 }
240 }
241 }
242
243 sorted
244 }
245}
246
/// Snapshot of the live database schema, discovered via `sea_schema` in `SchemaBuilder::sync`.
struct DiscoveredSchema {
    /// Existing tables, re-expressed as `CREATE TABLE` statements for comparison.
    tables: Vec<TableCreateStatement>,
    /// Existing enum types; only populated for Postgres (left empty for MySQL and SQLite).
    enums: Vec<TypeCreateStatement>,
}
251
252impl EntitySchemaInfo {
253 pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
255 Self {
256 table: helper.create_table_from_entity(entity),
257 enums: helper.create_enum_from_entity(entity),
258 indexes: helper.create_index_from_entity(entity),
259 }
260 }
261
262 async fn apply<C: ConnectionTrait>(
263 &self,
264 db: &C,
265 created_enums: &mut Vec<Statement>,
266 ) -> Result<(), DbErr> {
267 for stmt in self.enums.iter() {
268 let new_stmt = db.get_database_backend().build(stmt);
269 if !created_enums.iter().any(|s| s == &new_stmt) {
270 db.execute(stmt).await?;
271 created_enums.push(new_stmt);
272 }
273 }
274 db.execute(&self.table).await?;
275 for stmt in self.indexes.iter() {
276 db.execute(stmt).await?;
277 }
278 Ok(())
279 }
280
281 #[allow(dead_code)]
283 async fn sync<C: ConnectionTrait>(
284 &self,
285 db: &C,
286 existing: &DiscoveredSchema,
287 created_enums: &mut Vec<Statement>,
288 ) -> Result<(), DbErr> {
289 let db_backend = db.get_database_backend();
290
291 for stmt in self.enums.iter() {
293 let mut has_enum = false;
294 let new_stmt = db_backend.build(stmt);
295 for exsiting_enum in &existing.enums {
296 if db_backend.build(exsiting_enum) == new_stmt {
297 has_enum = true;
298 break;
300 }
301 }
302 if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
303 db.execute(stmt).await?;
304 created_enums.push(new_stmt);
305 }
306 }
307 let table_name = get_table_name(self.table.get_table_name());
308 let mut existing_table = None;
309 for tbl in &existing.tables {
310 if get_table_name(tbl.get_table_name()) == table_name {
311 existing_table = Some(tbl);
312 break;
313 }
314 }
315 if let Some(existing_table) = existing_table {
316 for column_def in self.table.get_columns() {
317 let mut column_exists = false;
318 for existing_column in existing_table.get_columns() {
319 if column_def.get_column_name() == existing_column.get_column_name() {
320 column_exists = true;
321 break;
322 }
323 }
324 if !column_exists {
325 let mut renamed_from = "";
326 if let Some(comment) = &column_def.get_column_spec().comment {
327 if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") {
328 if let Some((prefix, _)) = suffix.split_once('"') {
329 renamed_from = prefix;
330 }
331 }
332 }
333 if renamed_from.is_empty() {
334 db.execute(
335 TableAlterStatement::new()
336 .table(self.table.get_table_name().expect("Checked above").clone())
337 .add_column(column_def.to_owned()),
338 )
339 .await?;
340 } else {
341 db.execute(
342 TableAlterStatement::new()
343 .table(self.table.get_table_name().expect("Checked above").clone())
344 .rename_column(
345 renamed_from.to_owned(),
346 column_def.get_column_name(),
347 ),
348 )
349 .await?;
350 }
351 }
352 }
353 if db.get_database_backend() != DbBackend::Sqlite {
354 for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
355 let mut key_exists = false;
356 for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
357 if compare_foreign_key(foreign_key, existing_key) {
358 key_exists = true;
359 break;
360 }
361 }
362 if !key_exists {
363 db.execute(foreign_key).await?;
364 }
365 }
366 }
367 } else {
368 db.execute(&self.table).await?;
369 }
370 for stmt in self.indexes.iter() {
371 let mut has_index = false;
372 if let Some(existing_table) = existing_table {
373 for exsiting_index in existing_table.get_indexes() {
374 if exsiting_index.get_index_spec().get_column_names()
375 == stmt.get_index_spec().get_column_names()
376 {
377 has_index = true;
378 break;
379 }
380 }
381 }
382 if !has_index {
383 let mut stmt = stmt.clone();
385 stmt.if_not_exists();
386 db.execute(&stmt).await?;
387 }
388 }
389 if let Some(existing_table) = existing_table {
390 for exsiting_index in existing_table.get_indexes() {
393 if exsiting_index.is_unique_key() {
394 let mut has_index = false;
395 for stmt in self.indexes.iter() {
396 if exsiting_index.get_index_spec().get_column_names()
397 == stmt.get_index_spec().get_column_names()
398 {
399 has_index = true;
400 break;
401 }
402 }
403 if !has_index {
404 if let Some(drop_existing) = exsiting_index.get_index_spec().get_name() {
405 db.execute(sea_query::Index::drop().name(drop_existing))
406 .await?;
407 }
408 }
409 }
410 }
411 }
412 Ok(())
413 }
414
415 fn debug_print(
416 &self,
417 f: &mut std::fmt::Formatter<'_>,
418 backend: &DbBackend,
419 ) -> std::fmt::Result {
420 write!(f, "EntitySchemaInfo {{")?;
421 write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
422 write!(f, " enums: [")?;
423 for (i, stmt) in self.enums.iter().enumerate() {
424 if i > 0 {
425 write!(f, ", ")?;
426 }
427 write!(f, "{:?}", backend.build(stmt).to_string())?;
428 }
429 write!(f, " ]")?;
430 write!(f, " indexes: [")?;
431 for (i, stmt) in self.indexes.iter().enumerate() {
432 if i > 0 {
433 write!(f, ", ")?;
434 }
435 write!(f, "{:?}", backend.build(stmt).to_string())?;
436 }
437 write!(f, " ]")?;
438 write!(f, " }}")
439 }
440}
441
442fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
443 match table_ref {
444 Some(TableRef::Table(table_name, _)) => table_name.clone(),
445 None => panic!("Expect TableCreateStatement is properly built"),
446 _ => unreachable!("Unexpected {table_ref:?}"),
447 }
448}
449
450fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
451 let a = a.get_foreign_key();
452 let b = b.get_foreign_key();
453
454 a.get_name() == b.get_name()
455 || (a.get_ref_table() == b.get_ref_table()
456 && a.get_columns() == b.get_columns()
457 && a.get_ref_columns() == b.get_ref_columns())
458}