Skip to main content

yauth_migration/
diff.rs

1//! Schema diff engine.
2//!
3//! Compares two `YAuthSchema` snapshots (previous plugins vs current)
4//! and produces incremental SQL operations.
5
6use crate::collector::YAuthSchema;
7use crate::mysql::{mysql_default, mysql_type};
8use crate::postgres::pg_type;
9use crate::sqlite::{sqlite_default, sqlite_type};
10use crate::types::TableDef;
11
12/// A single schema change operation.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum SchemaChange {
15    /// A new table needs to be created (includes all columns).
16    CreateTable(TableDef),
17    /// An existing table needs to be dropped.
18    DropTable(TableDef),
19    /// A new column needs to be added to an existing table.
20    AddColumn {
21        table_name: String,
22        column: crate::types::ColumnDef,
23    },
24    /// A column needs to be removed from an existing table.
25    DropColumn {
26        table_name: String,
27        column_name: String,
28    },
29}
30
31/// Compute the diff between two schemas.
32///
33/// `from` is the previous schema state, `to` is the desired state.
34/// Returns a list of changes needed to go from `from` to `to`.
35pub fn schema_diff(from: &YAuthSchema, to: &YAuthSchema) -> Vec<SchemaChange> {
36    let mut changes = Vec::new();
37
38    let from_tables: std::collections::HashMap<&str, &TableDef> =
39        from.tables.iter().map(|t| (t.name.as_str(), t)).collect();
40    let to_tables: std::collections::HashMap<&str, &TableDef> =
41        to.tables.iter().map(|t| (t.name.as_str(), t)).collect();
42
43    // Tables to create (in `to` but not in `from`) -- preserve topological order
44    for table in &to.tables {
45        if !from_tables.contains_key(table.name.as_str()) {
46            changes.push(SchemaChange::CreateTable(table.clone()));
47        }
48    }
49
50    // Tables to drop (in `from` but not in `to`) -- reverse topological order
51    for table in from.tables.iter().rev() {
52        if !to_tables.contains_key(table.name.as_str()) {
53            changes.push(SchemaChange::DropTable(table.clone()));
54        }
55    }
56
57    // Column-level changes for tables that exist in both
58    for table in &to.tables {
59        if let Some(from_table) = from_tables.get(table.name.as_str()) {
60            let from_cols: std::collections::HashSet<&str> =
61                from_table.columns.iter().map(|c| c.name.as_str()).collect();
62            let to_cols: std::collections::HashSet<&str> =
63                table.columns.iter().map(|c| c.name.as_str()).collect();
64
65            // New columns
66            for col in &table.columns {
67                if !from_cols.contains(col.name.as_str()) {
68                    changes.push(SchemaChange::AddColumn {
69                        table_name: table.name.clone(),
70                        column: col.clone(),
71                    });
72                }
73            }
74
75            // Dropped columns
76            for col in &from_table.columns {
77                if !to_cols.contains(col.name.as_str()) {
78                    changes.push(SchemaChange::DropColumn {
79                        table_name: table.name.clone(),
80                        column_name: col.name.clone(),
81                    });
82                }
83            }
84        }
85    }
86
87    changes
88}
89
90/// Render schema changes as SQL for the given dialect.
91pub fn render_changes_sql(changes: &[SchemaChange], dialect: crate::Dialect) -> (String, String) {
92    let mut up_sql = String::new();
93    let mut down_sql = String::new();
94
95    for change in changes {
96        match change {
97            SchemaChange::CreateTable(table) => {
98                let schema = YAuthSchema {
99                    tables: vec![table.clone()],
100                };
101                let create = match dialect {
102                    crate::Dialect::Postgres => crate::generate_postgres_ddl(&schema),
103                    crate::Dialect::Sqlite => {
104                        // Don't include PRAGMA for individual table creates
105                        generate_single_table_sqlite(table)
106                    }
107                    crate::Dialect::Mysql => crate::generate_mysql_ddl(&schema),
108                };
109                up_sql.push_str(&create);
110                up_sql.push('\n');
111
112                // Down: drop
113                let drop = match dialect {
114                    crate::Dialect::Postgres => crate::generate_postgres_drop(table),
115                    crate::Dialect::Sqlite => crate::generate_sqlite_drop(table),
116                    crate::Dialect::Mysql => crate::generate_mysql_drop(table),
117                };
118                down_sql.push_str(&drop);
119                down_sql.push('\n');
120            }
121            SchemaChange::DropTable(table) => {
122                let drop = match dialect {
123                    crate::Dialect::Postgres => crate::generate_postgres_drop(table),
124                    crate::Dialect::Sqlite => crate::generate_sqlite_drop(table),
125                    crate::Dialect::Mysql => crate::generate_mysql_drop(table),
126                };
127                up_sql.push_str(&drop);
128                up_sql.push('\n');
129
130                // Down: recreate
131                let schema = YAuthSchema {
132                    tables: vec![table.clone()],
133                };
134                let create = match dialect {
135                    crate::Dialect::Postgres => crate::generate_postgres_ddl(&schema),
136                    crate::Dialect::Sqlite => generate_single_table_sqlite(table),
137                    crate::Dialect::Mysql => crate::generate_mysql_ddl(&schema),
138                };
139                down_sql.push_str(&create);
140                down_sql.push('\n');
141            }
142            SchemaChange::AddColumn { table_name, column } => {
143                let stmt = render_add_column(table_name, column, dialect);
144                up_sql.push_str(&stmt);
145                up_sql.push('\n');
146
147                let drop_stmt = render_drop_column(table_name, &column.name, dialect);
148                down_sql.push_str(&drop_stmt);
149                down_sql.push('\n');
150            }
151            SchemaChange::DropColumn {
152                table_name,
153                column_name,
154            } => {
155                let stmt = render_drop_column(table_name, column_name, dialect);
156                up_sql.push_str(&stmt);
157                up_sql.push('\n');
158                // Down for drop column is hard without the original column def,
159                // so we add a comment.
160                down_sql.push_str(&format!(
161                    "-- TODO: Re-add column {column_name} to {table_name}\n\n"
162                ));
163            }
164        }
165    }
166
167    (up_sql, down_sql)
168}
169
170fn render_add_column(
171    table_name: &str,
172    column: &crate::types::ColumnDef,
173    dialect: crate::Dialect,
174) -> String {
175    match dialect {
176        crate::Dialect::Postgres => {
177            let col_type = pg_type(&column.col_type);
178            let mut stmt = format!(
179                "ALTER TABLE {} ADD COLUMN {} {}",
180                table_name, column.name, col_type
181            );
182            if !column.nullable && column.default.is_none() {
183                // Can't add NOT NULL without a default to a table with existing rows
184                stmt.push_str(" NULL");
185            } else {
186                if !column.nullable {
187                    stmt.push_str(" NOT NULL");
188                }
189                if let Some(ref default) = column.default {
190                    stmt.push_str(&format!(" DEFAULT {}", default));
191                }
192            }
193            stmt.push_str(";\n");
194            stmt
195        }
196        crate::Dialect::Sqlite => {
197            let col_type = sqlite_type(&column.col_type);
198            let mut stmt = format!(
199                "ALTER TABLE {} ADD COLUMN {} {}",
200                table_name, column.name, col_type
201            );
202            if !column.nullable && column.default.is_none() {
203                stmt.push_str(" NULL");
204            } else {
205                if !column.nullable {
206                    stmt.push_str(" NOT NULL");
207                }
208                if let Some(ref default) = column.default
209                    && let Some(d) = sqlite_default(default)
210                {
211                    stmt.push_str(&format!(" DEFAULT {}", d));
212                }
213            }
214            stmt.push_str(";\n");
215            stmt
216        }
217        crate::Dialect::Mysql => {
218            let col_type = mysql_type(&column.col_type);
219            let mut stmt = format!(
220                "ALTER TABLE `{}` ADD COLUMN `{}` {}",
221                table_name, column.name, col_type
222            );
223            if !column.nullable && column.default.is_none() {
224                stmt.push_str(" NULL");
225            } else {
226                if !column.nullable {
227                    stmt.push_str(" NOT NULL");
228                }
229                if let Some(ref default) = column.default
230                    && let Some(d) = mysql_default(default)
231                {
232                    stmt.push_str(&format!(" DEFAULT {}", d));
233                }
234            }
235            stmt.push_str(";\n");
236            stmt
237        }
238    }
239}
240
241fn render_drop_column(table_name: &str, column_name: &str, dialect: crate::Dialect) -> String {
242    match dialect {
243        crate::Dialect::Postgres => {
244            format!(
245                "ALTER TABLE {} DROP COLUMN IF EXISTS {};\n",
246                table_name, column_name
247            )
248        }
249        crate::Dialect::Sqlite => {
250            format!("ALTER TABLE {} DROP COLUMN {};\n", table_name, column_name)
251        }
252        crate::Dialect::Mysql => {
253            format!(
254                "ALTER TABLE `{}` DROP COLUMN `{}`;\n",
255                table_name, column_name
256            )
257        }
258    }
259}
260
261fn generate_single_table_sqlite(table: &TableDef) -> String {
262    // Reuse the full generator but strip the PRAGMA
263    let schema = YAuthSchema {
264        tables: vec![table.clone()],
265    };
266    let full = crate::generate_sqlite_ddl(&schema);
267    // Strip PRAGMA line
268    full.lines()
269        .filter(|l| !l.starts_with("PRAGMA"))
270        .collect::<Vec<_>>()
271        .join("\n")
272        .trim_start_matches('\n')
273        .to_string()
274        + "\n"
275}
276
277/// Format a text diff of two SQL strings for display.
278pub fn format_sql_diff(old: &str, new: &str) -> String {
279    use similar::{ChangeTag, TextDiff};
280
281    let diff = TextDiff::from_lines(old, new);
282    let mut output = String::new();
283
284    for change in diff.iter_all_changes() {
285        let sign = match change.tag() {
286            ChangeTag::Delete => "-",
287            ChangeTag::Insert => "+",
288            ChangeTag::Equal => " ",
289        };
290        output.push_str(sign);
291        output.push_str(change.as_str().unwrap_or(""));
292        if !change.as_str().unwrap_or("").ends_with('\n') {
293            output.push('\n');
294        }
295    }
296
297    output
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{collect_schema, core_schema, plugin_schemas};

    /// Schema built from the core tables only.
    fn core_only() -> YAuthSchema {
        collect_schema(vec![core_schema()]).unwrap()
    }

    /// Schema built from the core tables plus the MFA plugin.
    fn core_with_mfa() -> YAuthSchema {
        collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap()
    }

    /// Table names of `changes`, requiring every entry to be a `CreateTable`.
    fn created_names(changes: &[SchemaChange]) -> Vec<&str> {
        changes
            .iter()
            .map(|change| match change {
                SchemaChange::CreateTable(table) => table.name.as_str(),
                other => panic!("expected CreateTable, got {other:?}"),
            })
            .collect()
    }

    #[test]
    fn diff_empty_to_core_creates_tables() {
        let empty = YAuthSchema { tables: vec![] };
        let changes = schema_diff(&empty, &core_only());

        assert_eq!(
            created_names(&changes),
            ["yauth_users", "yauth_sessions", "yauth_audit_log"]
        );
    }

    #[test]
    fn diff_add_plugin_creates_plugin_tables() {
        let changes = schema_diff(&core_only(), &core_with_mfa());

        assert_eq!(
            created_names(&changes),
            ["yauth_totp_secrets", "yauth_backup_codes"]
        );
    }

    #[test]
    fn diff_remove_plugin_drops_plugin_tables() {
        let with_passkey =
            collect_schema(vec![core_schema(), plugin_schemas::passkey_schema()]).unwrap();

        let changes = schema_diff(&with_passkey, &core_only());
        match changes.as_slice() {
            [SchemaChange::DropTable(table)] => {
                assert_eq!(table.name, "yauth_webauthn_credentials");
            }
            other => panic!("expected exactly one DropTable, got {other:?}"),
        }
    }

    #[test]
    fn diff_no_changes() {
        let schema = core_only();
        assert!(schema_diff(&schema, &schema).is_empty());
    }

    #[test]
    fn diff_add_mfa_produces_valid_postgres_sql() {
        let changes = schema_diff(&core_only(), &core_with_mfa());
        let (up, down) = render_changes_sql(&changes, crate::Dialect::Postgres);

        for expected in [
            "CREATE TABLE IF NOT EXISTS yauth_totp_secrets",
            "CREATE TABLE IF NOT EXISTS yauth_backup_codes",
        ] {
            assert!(up.contains(expected));
        }
        for expected in [
            "DROP TABLE IF EXISTS yauth_totp_secrets CASCADE",
            "DROP TABLE IF EXISTS yauth_backup_codes CASCADE",
        ] {
            assert!(down.contains(expected));
        }
    }

    #[test]
    fn diff_add_mfa_produces_valid_sqlite_sql() {
        let changes = schema_diff(&core_only(), &core_with_mfa());
        let (up, _down) = render_changes_sql(&changes, crate::Dialect::Sqlite);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS yauth_totp_secrets"));
        // Individual table creates shouldn't have PRAGMA
        assert!(!up.contains("PRAGMA"));
    }

    #[test]
    fn diff_add_mfa_produces_valid_mysql_sql() {
        let changes = schema_diff(&core_only(), &core_with_mfa());
        let (up, _down) = render_changes_sql(&changes, crate::Dialect::Mysql);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS `yauth_totp_secrets`"));
        assert!(up.contains("ENGINE=InnoDB"));
    }

    #[test]
    fn diff_complex_add_and_remove() {
        // Start with email-password + passkey, end with email-password + mfa
        let before = collect_schema(vec![
            core_schema(),
            plugin_schemas::email_password_schema(),
            plugin_schemas::passkey_schema(),
        ])
        .unwrap();
        let after = collect_schema(vec![
            core_schema(),
            plugin_schemas::email_password_schema(),
            plugin_schemas::mfa_schema(),
        ])
        .unwrap();

        let changes = schema_diff(&before, &after);

        // Should create mfa tables and drop passkey table
        let create_count = changes
            .iter()
            .filter(|change| matches!(change, SchemaChange::CreateTable(_)))
            .count();
        let drop_count = changes
            .iter()
            .filter(|change| matches!(change, SchemaChange::DropTable(_)))
            .count();

        assert_eq!(create_count, 2); // totp_secrets + backup_codes
        assert_eq!(drop_count, 1); // webauthn_credentials
    }

    #[test]
    fn format_diff_shows_additions() {
        let rendered = format_sql_diff("line1\nline2\n", "line1\nline2\nline3\n");
        assert!(rendered.contains("+line3"));
    }
}