// rustauth_core/db/sql/migrations.rs

1use super::*;
2
3/// Executes a pure migration plan through any SQL executor.
4///
5/// Introspection and transaction ownership stay in the adapter crate; this
6/// helper only runs the already-planned SQL statements in order.
7pub async fn execute_schema_migration_plan<E>(
8    executor: &mut E,
9    plan: &SchemaMigrationPlan,
10) -> Result<(), RustAuthError>
11where
12    E: SqlExecutor,
13{
14    for statement in &plan.statements {
15        executor
16            .execute(SqlStatement::new(statement.sql.clone()))
17            .await?;
18    }
19    Ok(())
20}
21
22/// Rejects migration plans that carry non-executable warnings before any
23/// schema mutation runs.
24///
25/// Shared preflight so every SQL adapter refuses warning/error plans
26/// identically instead of silently mutating the database.
27pub fn ensure_executable_migration_plan(plan: &SchemaMigrationPlan) -> Result<(), RustAuthError> {
28    if !plan.has_warnings() {
29        return Ok(());
30    }
31
32    Err(RustAuthError::Adapter(format!(
33        "migration contains {} non-executable migration warnings; inspect plan_migrations or compile_migrations before applying",
34        plan.warnings.len()
35    )))
36}
37
38/// Additive schema changes planned for a live database.
39#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
40pub struct SchemaMigrationPlan {
41    pub to_be_created: Vec<TableToCreate>,
42    pub to_be_added: Vec<ColumnToAdd>,
43    pub indexes_to_be_created: Vec<IndexToCreate>,
44    pub warnings: Vec<SchemaMigrationWarning>,
45    pub statements: Vec<MigrationStatement>,
46}
47
48impl SchemaMigrationPlan {
49    pub fn is_empty(&self) -> bool {
50        self.statements.is_empty()
51    }
52
53    pub fn has_warnings(&self) -> bool {
54        !self.warnings.is_empty()
55    }
56
57    pub fn compile(&self) -> String {
58        if self.statements.is_empty() {
59            return ";".to_owned();
60        }
61
62        format!(
63            "{};",
64            self.statements
65                .iter()
66                .map(|statement| statement.sql.as_str())
67                .collect::<Vec<_>>()
68                .join(";\n\n")
69        )
70    }
71}
72
/// A table missing from the database and planned for creation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableToCreate {
    /// Logical table name, as yielded by `DbSchema::tables`.
    pub logical_name: String,
    /// Physical SQL table name (the table's `name`).
    pub table_name: String,
}
79
/// A column missing from an existing table and planned for additive creation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColumnToAdd {
    /// Logical name of the owning table in the RustAuth schema.
    pub table_logical_name: String,
    /// Physical SQL name of the owning table.
    pub table_name: String,
    /// Logical field name (the key in the table's `fields`).
    pub field_logical_name: String,
    /// Physical SQL column name (the field's `name`).
    pub column_name: String,
}
88
/// A standalone index missing from the database and planned for creation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IndexToCreate {
    /// Logical name of the indexed table in the RustAuth schema.
    pub table_logical_name: String,
    /// Physical SQL name of the indexed table.
    pub table_name: String,
    /// Logical name of the indexed field.
    pub field_logical_name: String,
    /// Physical SQL name of the indexed column.
    pub column_name: String,
    /// Dialect-sanitized index name. The planner builds it as
    /// `uidx_<table>_<field>` or `idx_<table>_<field>`.
    pub index_name: String,
    /// Whether the index is created as UNIQUE.
    pub unique: bool,
}
99
/// Non-executable findings discovered while planning migrations.
///
/// Each variant describes drift between an existing column and the target schema
/// that the additive planner cannot repair. `ensure_executable_migration_plan`
/// rejects any plan carrying one of these.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// Every variant name ends in `Mismatch`, which `clippy::enum_variant_names` would flag.
#[allow(clippy::enum_variant_names)]
pub enum SchemaMigrationWarning {
    /// The column's introspected SQL type does not match the field's type.
    ColumnTypeMismatch {
        table_name: String,
        column_name: String,
        /// SQL type the dialect renders for the field.
        expected: String,
        /// SQL type recorded in the snapshot.
        actual: String,
    },
    /// The column's nullability differs from `!field.required`.
    ///
    /// Only reported for non-id columns whose nullability is known.
    ColumnNullabilityMismatch {
        table_name: String,
        column_name: String,
        expected_nullable: bool,
        actual_nullable: bool,
    },
    /// The id column is recorded as explicitly not being a primary key.
    PrimaryKeyMismatch {
        table_name: String,
        column_name: String,
    },
    /// The id column's generation strategy differs from the field's `generated_id`.
    GeneratedIdMismatch {
        table_name: String,
        column_name: String,
        expected: IdGeneration,
        /// `None` when the snapshot recorded no generation strategy.
        actual: Option<IdGeneration>,
    },
    /// The column's foreign key is missing or differs from the field's declared one.
    ForeignKeyMismatch {
        table_name: String,
        column_name: String,
        expected: ForeignKey,
        actual: Option<ForeignKey>,
    },
}
133
/// A SQL statement emitted by a migration plan.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MigrationStatement {
    /// Which additive operation this statement performs.
    pub kind: MigrationStatementKind,
    /// Dialect-rendered SQL text, without a trailing `;` (`compile` adds separators).
    pub sql: String,
}
140
/// The additive operation represented by a migration statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStatementKind {
    /// A `create_table_statement` for a table absent from the snapshot.
    CreateTable,
    /// An `add_column_statement` for a column absent from an existing table.
    AddColumn,
    /// A `create_index_statement` for a standalone (possibly unique) index.
    CreateIndex,
}
148
149/// Introspected database schema used by the pure migration planner.
150#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
151pub struct SqlSchemaSnapshot {
152    tables: IndexMap<String, SqlTableSnapshot>,
153}
154
155impl SqlSchemaSnapshot {
156    pub fn with_table(mut self, table: impl Into<String>) -> Self {
157        self.tables.entry(table.into()).or_default();
158        self
159    }
160
161    pub fn with_column(mut self, table: impl Into<String>, column: SqlColumnSnapshot) -> Self {
162        self.tables
163            .entry(table.into())
164            .or_default()
165            .columns
166            .insert(column.name.clone(), column);
167        self
168    }
169
170    pub fn with_index(mut self, table: impl Into<String>, index: impl Into<String>) -> Self {
171        self.tables
172            .entry(table.into())
173            .or_default()
174            .indexes
175            .insert(index.into());
176        self
177    }
178
179    pub fn with_unique_column(
180        mut self,
181        table: impl Into<String>,
182        column: impl Into<String>,
183    ) -> Self {
184        self.tables
185            .entry(table.into())
186            .or_default()
187            .unique_columns
188            .insert(column.into());
189        self
190    }
191
192    pub fn table_exists(&self, table: &str) -> bool {
193        self.tables.contains_key(table)
194    }
195
196    pub fn column_type(&self, table: &str, column: &str) -> Option<&str> {
197        self.column(table, column)
198            .map(|column| column.data_type.as_str())
199    }
200
201    pub fn column(&self, table: &str, column: &str) -> Option<&SqlColumnSnapshot> {
202        self.tables
203            .get(table)
204            .and_then(|table| table.columns.get(column))
205    }
206
207    pub fn index_exists(&self, table: &str, index: &str) -> bool {
208        self.tables
209            .get(table)
210            .is_some_and(|table| table.indexes.contains(index))
211            || self
212                .tables
213                .values()
214                .any(|table| table.indexes.contains(index))
215    }
216
217    pub fn unique_column_exists(&self, table: &str, column: &str) -> bool {
218        self.tables
219            .get(table)
220            .is_some_and(|table| table.unique_columns.contains(column))
221    }
222}
223
/// Introspected table metadata used by the pure migration planner.
///
/// In this file it is populated only through the `SqlSchemaSnapshot::with_*` builders.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlTableSnapshot {
    /// Column metadata keyed by column name, in insertion order.
    columns: IndexMap<String, SqlColumnSnapshot>,
    /// Names of indexes recorded on this table.
    indexes: IndexSet<String>,
    /// Columns recorded via `with_unique_column`; the planner creates no unique index for these.
    unique_columns: IndexSet<String>,
}
231
232/// Introspected column metadata used by the pure migration planner.
233#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
234pub struct SqlColumnSnapshot {
235    pub name: String,
236    pub data_type: String,
237    pub nullable: Option<bool>,
238    pub primary_key: Option<bool>,
239    pub generated_id: Option<IdGeneration>,
240    pub foreign_key: Option<ForeignKey>,
241}
242
243impl SqlColumnSnapshot {
244    pub fn new(name: impl Into<String>, data_type: impl Into<String>) -> Self {
245        Self {
246            name: name.into(),
247            data_type: data_type.into(),
248            nullable: None,
249            primary_key: None,
250            generated_id: None,
251            foreign_key: None,
252        }
253    }
254
255    pub fn nullable(mut self, nullable: bool) -> Self {
256        self.nullable = Some(nullable);
257        self
258    }
259
260    pub fn primary_key(mut self, primary_key: bool) -> Self {
261        self.primary_key = Some(primary_key);
262        self
263    }
264
265    pub fn generated_id(mut self, generated_id: Option<IdGeneration>) -> Self {
266        self.generated_id = generated_id;
267        self
268    }
269
270    pub fn references(mut self, foreign_key: ForeignKey) -> Self {
271        self.foreign_key = Some(foreign_key);
272        self
273    }
274}
275
276/// Compares a target RustAuth schema with a SQL schema snapshot and emits an additive plan.
277pub fn plan_schema_migration(
278    dialect: SqlDialect,
279    schema: &DbSchema,
280    snapshot: &SqlSchemaSnapshot,
281) -> Result<SchemaMigrationPlan, RustAuthError> {
282    let mut plan = SchemaMigrationPlan::default();
283    let mut tables = schema.tables().collect::<Vec<_>>();
284    tables.sort_by_key(|(_, table)| table.order.unwrap_or(u16::MAX));
285
286    for (table_logical_name, table) in &tables {
287        if snapshot.table_exists(&table.name) {
288            for (logical_name, field) in &table.fields {
289                if let Some(column) = snapshot.column(&table.name, &field.name) {
290                    if !dialect.type_matches(&column.data_type, field) {
291                        plan.warnings
292                            .push(SchemaMigrationWarning::ColumnTypeMismatch {
293                                table_name: table.name.clone(),
294                                column_name: field.name.clone(),
295                                expected: dialect.sql_type(logical_name, field),
296                                actual: column.data_type.clone(),
297                            });
298                    }
299                    push_constraint_warnings(&mut plan, table, logical_name, field, column);
300                } else {
301                    plan.to_be_added.push(ColumnToAdd {
302                        table_logical_name: (*table_logical_name).to_owned(),
303                        table_name: table.name.clone(),
304                        field_logical_name: logical_name.clone(),
305                        column_name: field.name.clone(),
306                    });
307                    plan.statements.push(MigrationStatement {
308                        kind: MigrationStatementKind::AddColumn,
309                        sql: dialect.add_column_statement(&table.name, logical_name, field)?,
310                    });
311                }
312            }
313        } else {
314            plan.to_be_created.push(TableToCreate {
315                logical_name: (*table_logical_name).to_owned(),
316                table_name: table.name.clone(),
317            });
318            plan.statements.push(MigrationStatement {
319                kind: MigrationStatementKind::CreateTable,
320                sql: dialect.create_table_statement(table)?,
321            });
322        }
323    }
324
325    for (table_logical_name, table) in tables {
326        let table_exists = snapshot.table_exists(&table.name);
327        for (logical_name, field) in &table.fields {
328            if field.index || field.unique {
329                if field.unique
330                    && (!table_exists || snapshot.unique_column_exists(&table.name, &field.name))
331                {
332                    continue;
333                }
334                let prefix = if field.unique { "uidx" } else { "idx" };
335                let index_name = dialect
336                    .sanitize_identifier(&format!("{prefix}_{}_{}", table.name, logical_name))?;
337                if !snapshot.index_exists(&table.name, &index_name) {
338                    plan.indexes_to_be_created.push(IndexToCreate {
339                        table_logical_name: table_logical_name.to_owned(),
340                        table_name: table.name.clone(),
341                        field_logical_name: logical_name.clone(),
342                        column_name: field.name.clone(),
343                        index_name: index_name.clone(),
344                        unique: field.unique,
345                    });
346                    plan.statements.push(MigrationStatement {
347                        kind: MigrationStatementKind::CreateIndex,
348                        sql: dialect.create_index_statement(
349                            &table.name,
350                            &field.name,
351                            &index_name,
352                            field.unique,
353                        )?,
354                    });
355                }
356            }
357        }
358    }
359
360    Ok(plan)
361}
362
363fn push_constraint_warnings(
364    plan: &mut SchemaMigrationPlan,
365    table: &DbTable,
366    logical_name: &str,
367    field: &DbField,
368    column: &SqlColumnSnapshot,
369) {
370    if logical_name == "id" || field.name == "id" {
371        if column.primary_key == Some(false) {
372            plan.warnings
373                .push(SchemaMigrationWarning::PrimaryKeyMismatch {
374                    table_name: table.name.clone(),
375                    column_name: field.name.clone(),
376                });
377        }
378    } else if let Some(actual_nullable) = column.nullable {
379        let expected_nullable = !field.required;
380        if expected_nullable != actual_nullable {
381            plan.warnings
382                .push(SchemaMigrationWarning::ColumnNullabilityMismatch {
383                    table_name: table.name.clone(),
384                    column_name: field.name.clone(),
385                    expected_nullable,
386                    actual_nullable,
387                });
388        }
389    }
390
391    if logical_name == "id" || field.name == "id" {
392        if let Some(expected) = field.generated_id {
393            if column.generated_id != Some(expected) {
394                plan.warnings
395                    .push(SchemaMigrationWarning::GeneratedIdMismatch {
396                        table_name: table.name.clone(),
397                        column_name: field.name.clone(),
398                        expected,
399                        actual: column.generated_id,
400                    });
401            }
402        }
403    }
404
405    if let Some(expected) = &field.foreign_key {
406        if column.foreign_key.as_ref() != Some(expected) {
407            plan.warnings
408                .push(SchemaMigrationWarning::ForeignKeyMismatch {
409                    table_name: table.name.clone(),
410                    column_name: field.name.clone(),
411                    expected: expected.clone(),
412                    actual: column.foreign_key.clone(),
413                });
414        }
415    }
416}