1use super::{Schema, TopologicalSort};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4 ForeignKeyCreateStatement, IndexCreateStatement, TableAlterStatement, TableCreateStatement,
5 TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
/// Collects entity schemas and either creates them (`apply`) or incrementally
/// brings an existing database in line with them (`sync`).
pub struct SchemaBuilder {
    /// Helper used to generate each entity's DDL; its `backend` also drives `Debug` output.
    helper: Schema,
    /// Registered entities, in registration order.
    entities: Vec<EntitySchemaInfo>,
}
13
/// DDL statements generated for a single entity.
pub struct EntitySchemaInfo {
    /// `CREATE TABLE` statement, including the entity's columns and foreign keys.
    table: TableCreateStatement,
    /// Postgres-style `CREATE TYPE` statements for the enums used by the entity.
    enums: Vec<TypeCreateStatement>,
    /// `CREATE INDEX` statements for the entity's indexes.
    indexes: Vec<IndexCreateStatement>,
}
20
21impl std::fmt::Debug for SchemaBuilder {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "SchemaBuilder {{")?;
24 write!(f, " entities: [")?;
25 for (i, entity) in self.entities.iter().enumerate() {
26 if i > 0 {
27 write!(f, ", ")?;
28 }
29 entity.debug_print(f, &self.helper.backend)?;
30 }
31 write!(f, " ]")?;
32 write!(f, " }}")
33 }
34}
35
36impl std::fmt::Debug for EntitySchemaInfo {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 self.debug_print(f, &DbBackend::Sqlite)
39 }
40}
41
42impl SchemaBuilder {
43 pub fn new(schema: Schema) -> Self {
45 Self {
46 helper: schema,
47 entities: Default::default(),
48 }
49 }
50
51 pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
53 let entity = EntitySchemaInfo::new(entity, &self.helper);
54 if !self
55 .entities
56 .iter()
57 .any(|e| e.table.get_table_name() == entity.table.get_table_name())
58 {
59 self.entities.push(entity);
60 }
61 self
62 }
63
64 #[cfg(feature = "entity-registry")]
65 pub(crate) fn helper(&self) -> &Schema {
66 &self.helper
67 }
68
69 #[cfg(feature = "entity-registry")]
70 pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
71 self.entities.push(entity);
72 }
73
74 #[cfg(feature = "schema-sync")]
77 #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
78 pub fn sync<C>(self, db: &C) -> Result<(), DbErr>
79 where
80 C: ConnectionTrait + sea_schema::Connection,
81 {
82 let _existing = match db.get_database_backend() {
83 #[cfg(feature = "sqlx-mysql")]
84 DbBackend::MySql => {
85 use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
86
87 let current_schema: String = db
88 .query_one(
89 sea_query::SelectStatement::new()
90 .expr(sea_schema::mysql::MySql::get_current_schema()),
91 )?
92 .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
93 .try_get_by_index(0)?;
94 let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema);
95
96 let schema = schema_discovery
97 .discover_with(db)
98 .map_err(|err| DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}"))))?;
99
100 DiscoveredSchema {
101 tables: schema.tables.iter().map(|table| table.write()).collect(),
102 enums: vec![],
103 }
104 }
105 #[cfg(feature = "sqlx-postgres")]
106 DbBackend::Postgres => {
107 use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
108
109 let current_schema: String = db
110 .query_one(
111 sea_query::SelectStatement::new()
112 .expr(sea_schema::postgres::Postgres::get_current_schema()),
113 )?
114 .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
115 .try_get_by_index(0)?;
116 let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema);
117
118 let schema = schema_discovery
119 .discover_with(db)
120 .map_err(|err| DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}"))))?;
121
122 DiscoveredSchema {
123 tables: schema.tables.iter().map(|table| table.write()).collect(),
124 enums: schema.enums.iter().map(|def| def.write()).collect(),
125 }
126 }
127 #[cfg(feature = "sqlx-sqlite")]
128 DbBackend::Sqlite => {
129 use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
130 let schema = SchemaDiscovery::discover_with(db)
131 .map_err(|err| {
132 DbErr::Query(match err {
133 SqliteDiscoveryError::SqlxError(err) => {
134 crate::RuntimeErr::SqlxError(err.into())
135 }
136 _ => crate::RuntimeErr::Internal(format!("{err:?}")),
137 })
138 })?
139 .merge_indexes_into_table();
140 DiscoveredSchema {
141 tables: schema.tables.iter().map(|table| table.write()).collect(),
142 enums: vec![],
143 }
144 }
145 #[cfg(feature = "rusqlite")]
146 DbBackend::Sqlite => {
147 use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
148 let schema = SchemaDiscovery::discover_with(db)
149 .map_err(|err| {
150 DbErr::Query(match err {
151 SqliteDiscoveryError::RusqliteError(err) => {
152 crate::RuntimeErr::Rusqlite(err.into())
153 }
154 _ => crate::RuntimeErr::Internal(format!("{err:?}")),
155 })
156 })?
157 .merge_indexes_into_table();
158 DiscoveredSchema {
159 tables: schema.tables.iter().map(|table| table.write()).collect(),
160 enums: vec![],
161 }
162 }
163 #[allow(unreachable_patterns)]
164 other => {
165 return Err(DbErr::BackendNotSupported {
166 db: other.as_str(),
167 ctx: "SchemaBuilder::sync",
168 });
169 }
170 };
171
172 #[allow(unreachable_code)]
173 let mut created_enums: Vec<Statement> = Default::default();
174
175 #[allow(unreachable_code)]
176 for table_name in self.sorted_tables() {
177 if let Some(entity) = self
178 .entities
179 .iter()
180 .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
181 {
182 entity.sync(db, &_existing, &mut created_enums)?;
183 }
184 }
185
186 Ok(())
187 }
188
189 pub fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
192 let mut created_enums: Vec<Statement> = Default::default();
193
194 for table_name in self.sorted_tables() {
195 if let Some(entity) = self
196 .entities
197 .iter()
198 .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
199 {
200 entity.apply(db, &mut created_enums)?;
201 }
202 }
203
204 Ok(())
205 }
206
207 fn sorted_tables(&self) -> Vec<TableName> {
208 let mut sorter = TopologicalSort::<TableName>::new();
209
210 for entity in self.entities.iter() {
211 let table_name = get_table_name(entity.table.get_table_name());
212 sorter.insert(table_name);
213 }
214 for entity in self.entities.iter() {
215 let self_table = get_table_name(entity.table.get_table_name());
216 for fk in entity.table.get_foreign_key_create_stmts().iter() {
217 let fk = fk.get_foreign_key();
218 let ref_table = get_table_name(fk.get_ref_table());
219 if self_table != ref_table {
220 sorter.add_dependency(ref_table, self_table.clone());
222 }
223 }
224 }
225 let mut sorted = Vec::new();
226 while let Some(i) = sorter.pop() {
227 sorted.push(i);
228 }
229 if sorted.len() != self.entities.len() {
230 for entity in self.entities.iter() {
232 let table_name = get_table_name(entity.table.get_table_name());
233 if !sorted.contains(&table_name) {
234 sorted.push(table_name);
235 }
236 }
237 }
238
239 sorted
240 }
241}
242
/// Schema read back from the live database, which `sync` diffs entities against.
struct DiscoveredSchema {
    /// Existing tables, re-expressed as create statements.
    tables: Vec<TableCreateStatement>,
    /// Existing enum types. Only the Postgres discovery path fills this; other backends leave it empty.
    enums: Vec<TypeCreateStatement>,
}
247
248impl EntitySchemaInfo {
249 pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
251 Self {
252 table: helper.create_table_from_entity(entity),
253 enums: helper.create_enum_from_entity(entity),
254 indexes: helper.create_index_from_entity(entity),
255 }
256 }
257
258 fn apply<C: ConnectionTrait>(
259 &self,
260 db: &C,
261 created_enums: &mut Vec<Statement>,
262 ) -> Result<(), DbErr> {
263 for stmt in self.enums.iter() {
264 let new_stmt = db.get_database_backend().build(stmt);
265 if !created_enums.iter().any(|s| s == &new_stmt) {
266 db.execute(stmt)?;
267 created_enums.push(new_stmt);
268 }
269 }
270 db.execute(&self.table)?;
271 for stmt in self.indexes.iter() {
272 db.execute(stmt)?;
273 }
274 Ok(())
275 }
276
277 #[allow(dead_code)]
279 fn sync<C: ConnectionTrait>(
280 &self,
281 db: &C,
282 existing: &DiscoveredSchema,
283 created_enums: &mut Vec<Statement>,
284 ) -> Result<(), DbErr> {
285 let db_backend = db.get_database_backend();
286
287 for stmt in self.enums.iter() {
289 let mut has_enum = false;
290 let new_stmt = db_backend.build(stmt);
291 for exsiting_enum in &existing.enums {
292 if db_backend.build(exsiting_enum) == new_stmt {
293 has_enum = true;
294 break;
296 }
297 }
298 if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
299 db.execute(stmt)?;
300 created_enums.push(new_stmt);
301 }
302 }
303 let table_name = get_table_name(self.table.get_table_name());
304 let mut existing_table = None;
305 for tbl in &existing.tables {
306 if get_table_name(tbl.get_table_name()) == table_name {
307 existing_table = Some(tbl);
308 break;
309 }
310 }
311 if let Some(existing_table) = existing_table {
312 for column_def in self.table.get_columns() {
313 let mut column_exists = false;
314 for existing_column in existing_table.get_columns() {
315 if column_def.get_column_name() == existing_column.get_column_name() {
316 column_exists = true;
317 break;
318 }
319 }
320 if !column_exists {
321 let mut renamed_from = "";
322 if let Some(comment) = &column_def.get_column_spec().comment {
323 if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") {
324 if let Some((prefix, _)) = suffix.split_once('"') {
325 renamed_from = prefix;
326 }
327 }
328 }
329 if renamed_from.is_empty() {
330 db.execute(
331 TableAlterStatement::new()
332 .table(self.table.get_table_name().expect("Checked above").clone())
333 .add_column(column_def.to_owned()),
334 )?;
335 } else {
336 db.execute(
337 TableAlterStatement::new()
338 .table(self.table.get_table_name().expect("Checked above").clone())
339 .rename_column(
340 renamed_from.to_owned(),
341 column_def.get_column_name(),
342 ),
343 )?;
344 }
345 }
346 }
347 if db.get_database_backend() != DbBackend::Sqlite {
348 for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
349 let mut key_exists = false;
350 for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
351 if compare_foreign_key(foreign_key, existing_key) {
352 key_exists = true;
353 break;
354 }
355 }
356 if !key_exists {
357 db.execute(foreign_key)?;
358 }
359 }
360 }
361 } else {
362 db.execute(&self.table)?;
363 }
364 for stmt in self.indexes.iter() {
365 let mut has_index = false;
366 if let Some(existing_table) = existing_table {
367 for exsiting_index in existing_table.get_indexes() {
368 if exsiting_index.get_index_spec().get_column_names()
369 == stmt.get_index_spec().get_column_names()
370 {
371 has_index = true;
372 break;
373 }
374 }
375 }
376 if !has_index {
377 let mut stmt = stmt.clone();
379 stmt.if_not_exists();
380 db.execute(&stmt)?;
381 }
382 }
383 if let Some(existing_table) = existing_table {
384 for exsiting_index in existing_table.get_indexes() {
387 if exsiting_index.is_unique_key() {
388 let mut has_index = false;
389 for stmt in self.indexes.iter() {
390 if exsiting_index.get_index_spec().get_column_names()
391 == stmt.get_index_spec().get_column_names()
392 {
393 has_index = true;
394 break;
395 }
396 }
397 if !has_index {
398 if let Some(drop_existing) = exsiting_index.get_index_spec().get_name() {
399 db.execute(sea_query::Index::drop().name(drop_existing))?;
400 }
401 }
402 }
403 }
404 }
405 Ok(())
406 }
407
408 fn debug_print(
409 &self,
410 f: &mut std::fmt::Formatter<'_>,
411 backend: &DbBackend,
412 ) -> std::fmt::Result {
413 write!(f, "EntitySchemaInfo {{")?;
414 write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
415 write!(f, " enums: [")?;
416 for (i, stmt) in self.enums.iter().enumerate() {
417 if i > 0 {
418 write!(f, ", ")?;
419 }
420 write!(f, "{:?}", backend.build(stmt).to_string())?;
421 }
422 write!(f, " ]")?;
423 write!(f, " indexes: [")?;
424 for (i, stmt) in self.indexes.iter().enumerate() {
425 if i > 0 {
426 write!(f, ", ")?;
427 }
428 write!(f, "{:?}", backend.build(stmt).to_string())?;
429 }
430 write!(f, " ]")?;
431 write!(f, " }}")
432 }
433}
434
435fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
436 match table_ref {
437 Some(TableRef::Table(table_name, _)) => table_name.clone(),
438 None => panic!("Expect TableCreateStatement is properly built"),
439 _ => unreachable!("Unexpected {table_ref:?}"),
440 }
441}
442
443fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
444 let a = a.get_foreign_key();
445 let b = b.get_foreign_key();
446
447 a.get_name() == b.get_name()
448 || (a.get_ref_table() == b.get_ref_table()
449 && a.get_columns() == b.get_columns()
450 && a.get_ref_columns() == b.get_ref_columns())
451}