1use super::*;
2
3pub async fn execute_schema_migration_plan<E>(
8 executor: &mut E,
9 plan: &SchemaMigrationPlan,
10) -> Result<(), RustAuthError>
11where
12 E: SqlExecutor,
13{
14 for statement in &plan.statements {
15 executor
16 .execute(SqlStatement::new(statement.sql.clone()))
17 .await?;
18 }
19 Ok(())
20}
21
22pub fn ensure_executable_migration_plan(plan: &SchemaMigrationPlan) -> Result<(), RustAuthError> {
28 if !plan.has_warnings() {
29 return Ok(());
30 }
31
32 Err(RustAuthError::Adapter(format!(
33 "migration contains {} non-executable migration warnings; inspect plan_migrations or compile_migrations before applying",
34 plan.warnings.len()
35 )))
36}
37
38#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
40pub struct SchemaMigrationPlan {
41 pub to_be_created: Vec<TableToCreate>,
42 pub to_be_added: Vec<ColumnToAdd>,
43 pub indexes_to_be_created: Vec<IndexToCreate>,
44 pub warnings: Vec<SchemaMigrationWarning>,
45 pub statements: Vec<MigrationStatement>,
46}
47
48impl SchemaMigrationPlan {
49 pub fn is_empty(&self) -> bool {
50 self.statements.is_empty()
51 }
52
53 pub fn has_warnings(&self) -> bool {
54 !self.warnings.is_empty()
55 }
56
57 pub fn compile(&self) -> String {
58 if self.statements.is_empty() {
59 return ";".to_owned();
60 }
61
62 format!(
63 "{};",
64 self.statements
65 .iter()
66 .map(|statement| statement.sql.as_str())
67 .collect::<Vec<_>>()
68 .join(";\n\n")
69 )
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub struct TableToCreate {
76 pub logical_name: String,
77 pub table_name: String,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
82pub struct ColumnToAdd {
83 pub table_logical_name: String,
84 pub table_name: String,
85 pub field_logical_name: String,
86 pub column_name: String,
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub struct IndexToCreate {
92 pub table_logical_name: String,
93 pub table_name: String,
94 pub field_logical_name: String,
95 pub column_name: String,
96 pub index_name: String,
97 pub unique: bool,
98}
99
100#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
102#[allow(clippy::enum_variant_names)]
103pub enum SchemaMigrationWarning {
104 ColumnTypeMismatch {
105 table_name: String,
106 column_name: String,
107 expected: String,
108 actual: String,
109 },
110 ColumnNullabilityMismatch {
111 table_name: String,
112 column_name: String,
113 expected_nullable: bool,
114 actual_nullable: bool,
115 },
116 PrimaryKeyMismatch {
117 table_name: String,
118 column_name: String,
119 },
120 GeneratedIdMismatch {
121 table_name: String,
122 column_name: String,
123 expected: IdGeneration,
124 actual: Option<IdGeneration>,
125 },
126 ForeignKeyMismatch {
127 table_name: String,
128 column_name: String,
129 expected: ForeignKey,
130 actual: Option<ForeignKey>,
131 },
132}
133
134#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136pub struct MigrationStatement {
137 pub kind: MigrationStatementKind,
138 pub sql: String,
139}
140
141#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
143pub enum MigrationStatementKind {
144 CreateTable,
145 AddColumn,
146 CreateIndex,
147}
148
149#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
151pub struct SqlSchemaSnapshot {
152 tables: IndexMap<String, SqlTableSnapshot>,
153}
154
155impl SqlSchemaSnapshot {
156 pub fn with_table(mut self, table: impl Into<String>) -> Self {
157 self.tables.entry(table.into()).or_default();
158 self
159 }
160
161 pub fn with_column(mut self, table: impl Into<String>, column: SqlColumnSnapshot) -> Self {
162 self.tables
163 .entry(table.into())
164 .or_default()
165 .columns
166 .insert(column.name.clone(), column);
167 self
168 }
169
170 pub fn with_index(mut self, table: impl Into<String>, index: impl Into<String>) -> Self {
171 self.tables
172 .entry(table.into())
173 .or_default()
174 .indexes
175 .insert(index.into());
176 self
177 }
178
179 pub fn with_unique_column(
180 mut self,
181 table: impl Into<String>,
182 column: impl Into<String>,
183 ) -> Self {
184 self.tables
185 .entry(table.into())
186 .or_default()
187 .unique_columns
188 .insert(column.into());
189 self
190 }
191
192 pub fn table_exists(&self, table: &str) -> bool {
193 self.tables.contains_key(table)
194 }
195
196 pub fn column_type(&self, table: &str, column: &str) -> Option<&str> {
197 self.column(table, column)
198 .map(|column| column.data_type.as_str())
199 }
200
201 pub fn column(&self, table: &str, column: &str) -> Option<&SqlColumnSnapshot> {
202 self.tables
203 .get(table)
204 .and_then(|table| table.columns.get(column))
205 }
206
207 pub fn index_exists(&self, table: &str, index: &str) -> bool {
208 self.tables
209 .get(table)
210 .is_some_and(|table| table.indexes.contains(index))
211 || self
212 .tables
213 .values()
214 .any(|table| table.indexes.contains(index))
215 }
216
217 pub fn unique_column_exists(&self, table: &str, column: &str) -> bool {
218 self.tables
219 .get(table)
220 .is_some_and(|table| table.unique_columns.contains(column))
221 }
222}
223
224#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
226pub struct SqlTableSnapshot {
227 columns: IndexMap<String, SqlColumnSnapshot>,
228 indexes: IndexSet<String>,
229 unique_columns: IndexSet<String>,
230}
231
232#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
234pub struct SqlColumnSnapshot {
235 pub name: String,
236 pub data_type: String,
237 pub nullable: Option<bool>,
238 pub primary_key: Option<bool>,
239 pub generated_id: Option<IdGeneration>,
240 pub foreign_key: Option<ForeignKey>,
241}
242
243impl SqlColumnSnapshot {
244 pub fn new(name: impl Into<String>, data_type: impl Into<String>) -> Self {
245 Self {
246 name: name.into(),
247 data_type: data_type.into(),
248 nullable: None,
249 primary_key: None,
250 generated_id: None,
251 foreign_key: None,
252 }
253 }
254
255 pub fn nullable(mut self, nullable: bool) -> Self {
256 self.nullable = Some(nullable);
257 self
258 }
259
260 pub fn primary_key(mut self, primary_key: bool) -> Self {
261 self.primary_key = Some(primary_key);
262 self
263 }
264
265 pub fn generated_id(mut self, generated_id: Option<IdGeneration>) -> Self {
266 self.generated_id = generated_id;
267 self
268 }
269
270 pub fn references(mut self, foreign_key: ForeignKey) -> Self {
271 self.foreign_key = Some(foreign_key);
272 self
273 }
274}
275
276pub fn plan_schema_migration(
278 dialect: SqlDialect,
279 schema: &DbSchema,
280 snapshot: &SqlSchemaSnapshot,
281) -> Result<SchemaMigrationPlan, RustAuthError> {
282 let mut plan = SchemaMigrationPlan::default();
283 let mut tables = schema.tables().collect::<Vec<_>>();
284 tables.sort_by_key(|(_, table)| table.order.unwrap_or(u16::MAX));
285
286 for (table_logical_name, table) in &tables {
287 if snapshot.table_exists(&table.name) {
288 for (logical_name, field) in &table.fields {
289 if let Some(column) = snapshot.column(&table.name, &field.name) {
290 if !dialect.type_matches(&column.data_type, field) {
291 plan.warnings
292 .push(SchemaMigrationWarning::ColumnTypeMismatch {
293 table_name: table.name.clone(),
294 column_name: field.name.clone(),
295 expected: dialect.sql_type(logical_name, field),
296 actual: column.data_type.clone(),
297 });
298 }
299 push_constraint_warnings(&mut plan, table, logical_name, field, column);
300 } else {
301 plan.to_be_added.push(ColumnToAdd {
302 table_logical_name: (*table_logical_name).to_owned(),
303 table_name: table.name.clone(),
304 field_logical_name: logical_name.clone(),
305 column_name: field.name.clone(),
306 });
307 plan.statements.push(MigrationStatement {
308 kind: MigrationStatementKind::AddColumn,
309 sql: dialect.add_column_statement(&table.name, logical_name, field)?,
310 });
311 }
312 }
313 } else {
314 plan.to_be_created.push(TableToCreate {
315 logical_name: (*table_logical_name).to_owned(),
316 table_name: table.name.clone(),
317 });
318 plan.statements.push(MigrationStatement {
319 kind: MigrationStatementKind::CreateTable,
320 sql: dialect.create_table_statement(table)?,
321 });
322 }
323 }
324
325 for (table_logical_name, table) in tables {
326 let table_exists = snapshot.table_exists(&table.name);
327 for (logical_name, field) in &table.fields {
328 if field.index || field.unique {
329 if field.unique
330 && (!table_exists || snapshot.unique_column_exists(&table.name, &field.name))
331 {
332 continue;
333 }
334 let prefix = if field.unique { "uidx" } else { "idx" };
335 let index_name = dialect
336 .sanitize_identifier(&format!("{prefix}_{}_{}", table.name, logical_name))?;
337 if !snapshot.index_exists(&table.name, &index_name) {
338 plan.indexes_to_be_created.push(IndexToCreate {
339 table_logical_name: table_logical_name.to_owned(),
340 table_name: table.name.clone(),
341 field_logical_name: logical_name.clone(),
342 column_name: field.name.clone(),
343 index_name: index_name.clone(),
344 unique: field.unique,
345 });
346 plan.statements.push(MigrationStatement {
347 kind: MigrationStatementKind::CreateIndex,
348 sql: dialect.create_index_statement(
349 &table.name,
350 &field.name,
351 &index_name,
352 field.unique,
353 )?,
354 });
355 }
356 }
357 }
358 }
359
360 Ok(plan)
361}
362
363fn push_constraint_warnings(
364 plan: &mut SchemaMigrationPlan,
365 table: &DbTable,
366 logical_name: &str,
367 field: &DbField,
368 column: &SqlColumnSnapshot,
369) {
370 if logical_name == "id" || field.name == "id" {
371 if column.primary_key == Some(false) {
372 plan.warnings
373 .push(SchemaMigrationWarning::PrimaryKeyMismatch {
374 table_name: table.name.clone(),
375 column_name: field.name.clone(),
376 });
377 }
378 } else if let Some(actual_nullable) = column.nullable {
379 let expected_nullable = !field.required;
380 if expected_nullable != actual_nullable {
381 plan.warnings
382 .push(SchemaMigrationWarning::ColumnNullabilityMismatch {
383 table_name: table.name.clone(),
384 column_name: field.name.clone(),
385 expected_nullable,
386 actual_nullable,
387 });
388 }
389 }
390
391 if logical_name == "id" || field.name == "id" {
392 if let Some(expected) = field.generated_id {
393 if column.generated_id != Some(expected) {
394 plan.warnings
395 .push(SchemaMigrationWarning::GeneratedIdMismatch {
396 table_name: table.name.clone(),
397 column_name: field.name.clone(),
398 expected,
399 actual: column.generated_id,
400 });
401 }
402 }
403 }
404
405 if let Some(expected) = &field.foreign_key {
406 if column.foreign_key.as_ref() != Some(expected) {
407 plan.warnings
408 .push(SchemaMigrationWarning::ForeignKeyMismatch {
409 table_name: table.name.clone(),
410 column_name: field.name.clone(),
411 expected: expected.clone(),
412 actual: column.foreign_key.clone(),
413 });
414 }
415 }
416}