1use crate::collector::YAuthSchema;
7use crate::mysql::{mysql_default, mysql_type};
8use crate::postgres::pg_type;
9use crate::sqlite::{sqlite_default, sqlite_type};
10use crate::types::TableDef;
11
/// A single structural difference between two [`YAuthSchema`]s, as produced by
/// [`schema_diff`] and turned into migration SQL by [`render_changes_sql`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaChange {
    /// A table that exists in the target schema but not in the source schema.
    CreateTable(TableDef),
    /// A table that exists in the source schema but not in the target schema.
    /// The full definition is kept so the down migration can re-create it.
    DropTable(TableDef),
    /// A column present on the target version of an existing table but not on
    /// the source version.
    AddColumn {
        /// Name of the table (present in both schemas) that gains the column.
        table_name: String,
        /// Full definition of the column being added.
        column: crate::types::ColumnDef,
    },
    /// A column present on the source version of an existing table but missing
    /// from the target version.
    ///
    /// Only the column name is recorded, so this change cannot be reversed
    /// automatically; the rendered down migration contains a TODO instead.
    DropColumn {
        /// Name of the table (present in both schemas) that loses the column.
        table_name: String,
        /// Name of the column being dropped.
        column_name: String,
    },
}
30
31pub fn schema_diff(from: &YAuthSchema, to: &YAuthSchema) -> Vec<SchemaChange> {
36 let mut changes = Vec::new();
37
38 let from_tables: std::collections::HashMap<&str, &TableDef> =
39 from.tables.iter().map(|t| (t.name.as_str(), t)).collect();
40 let to_tables: std::collections::HashMap<&str, &TableDef> =
41 to.tables.iter().map(|t| (t.name.as_str(), t)).collect();
42
43 for table in &to.tables {
45 if !from_tables.contains_key(table.name.as_str()) {
46 changes.push(SchemaChange::CreateTable(table.clone()));
47 }
48 }
49
50 for table in from.tables.iter().rev() {
52 if !to_tables.contains_key(table.name.as_str()) {
53 changes.push(SchemaChange::DropTable(table.clone()));
54 }
55 }
56
57 for table in &to.tables {
59 if let Some(from_table) = from_tables.get(table.name.as_str()) {
60 let from_cols: std::collections::HashSet<&str> =
61 from_table.columns.iter().map(|c| c.name.as_str()).collect();
62 let to_cols: std::collections::HashSet<&str> =
63 table.columns.iter().map(|c| c.name.as_str()).collect();
64
65 for col in &table.columns {
67 if !from_cols.contains(col.name.as_str()) {
68 changes.push(SchemaChange::AddColumn {
69 table_name: table.name.clone(),
70 column: col.clone(),
71 });
72 }
73 }
74
75 for col in &from_table.columns {
77 if !to_cols.contains(col.name.as_str()) {
78 changes.push(SchemaChange::DropColumn {
79 table_name: table.name.clone(),
80 column_name: col.name.clone(),
81 });
82 }
83 }
84 }
85 }
86
87 changes
88}
89
90pub fn render_changes_sql(changes: &[SchemaChange], dialect: crate::Dialect) -> (String, String) {
92 let mut up_sql = String::new();
93 let mut down_sql = String::new();
94
95 for change in changes {
96 match change {
97 SchemaChange::CreateTable(table) => {
98 let schema = YAuthSchema {
99 tables: vec![table.clone()],
100 };
101 let create = match dialect {
102 crate::Dialect::Postgres => crate::generate_postgres_ddl(&schema),
103 crate::Dialect::Sqlite => {
104 generate_single_table_sqlite(table)
106 }
107 crate::Dialect::Mysql => crate::generate_mysql_ddl(&schema),
108 };
109 up_sql.push_str(&create);
110 up_sql.push('\n');
111
112 let drop = match dialect {
114 crate::Dialect::Postgres => crate::generate_postgres_drop(table),
115 crate::Dialect::Sqlite => crate::generate_sqlite_drop(table),
116 crate::Dialect::Mysql => crate::generate_mysql_drop(table),
117 };
118 down_sql.push_str(&drop);
119 down_sql.push('\n');
120 }
121 SchemaChange::DropTable(table) => {
122 let drop = match dialect {
123 crate::Dialect::Postgres => crate::generate_postgres_drop(table),
124 crate::Dialect::Sqlite => crate::generate_sqlite_drop(table),
125 crate::Dialect::Mysql => crate::generate_mysql_drop(table),
126 };
127 up_sql.push_str(&drop);
128 up_sql.push('\n');
129
130 let schema = YAuthSchema {
132 tables: vec![table.clone()],
133 };
134 let create = match dialect {
135 crate::Dialect::Postgres => crate::generate_postgres_ddl(&schema),
136 crate::Dialect::Sqlite => generate_single_table_sqlite(table),
137 crate::Dialect::Mysql => crate::generate_mysql_ddl(&schema),
138 };
139 down_sql.push_str(&create);
140 down_sql.push('\n');
141 }
142 SchemaChange::AddColumn { table_name, column } => {
143 let stmt = render_add_column(table_name, column, dialect);
144 up_sql.push_str(&stmt);
145 up_sql.push('\n');
146
147 let drop_stmt = render_drop_column(table_name, &column.name, dialect);
148 down_sql.push_str(&drop_stmt);
149 down_sql.push('\n');
150 }
151 SchemaChange::DropColumn {
152 table_name,
153 column_name,
154 } => {
155 let stmt = render_drop_column(table_name, column_name, dialect);
156 up_sql.push_str(&stmt);
157 up_sql.push('\n');
158 down_sql.push_str(&format!(
161 "-- TODO: Re-add column {column_name} to {table_name}\n\n"
162 ));
163 }
164 }
165 }
166
167 (up_sql, down_sql)
168}
169
170fn render_add_column(
171 table_name: &str,
172 column: &crate::types::ColumnDef,
173 dialect: crate::Dialect,
174) -> String {
175 match dialect {
176 crate::Dialect::Postgres => {
177 let col_type = pg_type(&column.col_type);
178 let mut stmt = format!(
179 "ALTER TABLE {} ADD COLUMN {} {}",
180 table_name, column.name, col_type
181 );
182 if !column.nullable && column.default.is_none() {
183 stmt.push_str(" NULL");
185 } else {
186 if !column.nullable {
187 stmt.push_str(" NOT NULL");
188 }
189 if let Some(ref default) = column.default {
190 stmt.push_str(&format!(" DEFAULT {}", default));
191 }
192 }
193 stmt.push_str(";\n");
194 stmt
195 }
196 crate::Dialect::Sqlite => {
197 let col_type = sqlite_type(&column.col_type);
198 let mut stmt = format!(
199 "ALTER TABLE {} ADD COLUMN {} {}",
200 table_name, column.name, col_type
201 );
202 if !column.nullable && column.default.is_none() {
203 stmt.push_str(" NULL");
204 } else {
205 if !column.nullable {
206 stmt.push_str(" NOT NULL");
207 }
208 if let Some(ref default) = column.default
209 && let Some(d) = sqlite_default(default)
210 {
211 stmt.push_str(&format!(" DEFAULT {}", d));
212 }
213 }
214 stmt.push_str(";\n");
215 stmt
216 }
217 crate::Dialect::Mysql => {
218 let col_type = mysql_type(&column.col_type);
219 let mut stmt = format!(
220 "ALTER TABLE `{}` ADD COLUMN `{}` {}",
221 table_name, column.name, col_type
222 );
223 if !column.nullable && column.default.is_none() {
224 stmt.push_str(" NULL");
225 } else {
226 if !column.nullable {
227 stmt.push_str(" NOT NULL");
228 }
229 if let Some(ref default) = column.default
230 && let Some(d) = mysql_default(default)
231 {
232 stmt.push_str(&format!(" DEFAULT {}", d));
233 }
234 }
235 stmt.push_str(";\n");
236 stmt
237 }
238 }
239}
240
241fn render_drop_column(table_name: &str, column_name: &str, dialect: crate::Dialect) -> String {
242 match dialect {
243 crate::Dialect::Postgres => {
244 format!(
245 "ALTER TABLE {} DROP COLUMN IF EXISTS {};\n",
246 table_name, column_name
247 )
248 }
249 crate::Dialect::Sqlite => {
250 format!("ALTER TABLE {} DROP COLUMN {};\n", table_name, column_name)
251 }
252 crate::Dialect::Mysql => {
253 format!(
254 "ALTER TABLE `{}` DROP COLUMN `{}`;\n",
255 table_name, column_name
256 )
257 }
258 }
259}
260
261fn generate_single_table_sqlite(table: &TableDef) -> String {
262 let schema = YAuthSchema {
264 tables: vec![table.clone()],
265 };
266 let full = crate::generate_sqlite_ddl(&schema);
267 full.lines()
269 .filter(|l| !l.starts_with("PRAGMA"))
270 .collect::<Vec<_>>()
271 .join("\n")
272 .trim_start_matches('\n')
273 .to_string()
274 + "\n"
275}
276
277pub fn format_sql_diff(old: &str, new: &str) -> String {
279 use similar::{ChangeTag, TextDiff};
280
281 let diff = TextDiff::from_lines(old, new);
282 let mut output = String::new();
283
284 for change in diff.iter_all_changes() {
285 let sign = match change.tag() {
286 ChangeTag::Delete => "-",
287 ChangeTag::Insert => "+",
288 ChangeTag::Equal => " ",
289 };
290 output.push_str(sign);
291 output.push_str(change.as_str().unwrap_or(""));
292 if !change.as_str().unwrap_or("").ends_with('\n') {
293 output.push('\n');
294 }
295 }
296
297 output
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{collect_schema, core_schema, plugin_schemas};

    // Diffing from an empty schema must yield one CreateTable per core table.
    #[test]
    fn diff_empty_to_core_creates_tables() {
        let from = YAuthSchema { tables: vec![] };
        let to = collect_schema(vec![core_schema()]).unwrap();
        let changes = schema_diff(&from, &to);

        // Exactly the six core tables listed below, nothing else.
        assert_eq!(changes.len(), 6);
        let table_names: Vec<&str> = changes
            .iter()
            .filter_map(|c| match c {
                SchemaChange::CreateTable(t) => Some(t.name.as_str()),
                _ => None,
            })
            .collect();
        for expected in &[
            "yauth_users",
            "yauth_sessions",
            "yauth_audit_log",
            "yauth_challenges",
            "yauth_rate_limits",
            "yauth_revocations",
        ] {
            assert!(table_names.contains(expected), "Missing table: {expected}");
        }
    }

    // Adding the MFA plugin creates its two tables, in the plugin's declaration order.
    #[test]
    fn diff_add_plugin_creates_plugin_tables() {
        let from = collect_schema(vec![core_schema()]).unwrap();
        let to = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&from, &to);
        assert_eq!(changes.len(), 2);
        assert!(
            matches!(&changes[0], SchemaChange::CreateTable(t) if t.name == "yauth_totp_secrets")
        );
        assert!(
            matches!(&changes[1], SchemaChange::CreateTable(t) if t.name == "yauth_backup_codes")
        );
    }

    // Removing the passkey plugin drops its single table.
    #[test]
    fn diff_remove_plugin_drops_plugin_tables() {
        let from = collect_schema(vec![core_schema(), plugin_schemas::passkey_schema()]).unwrap();
        let to = collect_schema(vec![core_schema()]).unwrap();

        let changes = schema_diff(&from, &to);
        assert_eq!(changes.len(), 1);
        assert!(
            matches!(&changes[0], SchemaChange::DropTable(t) if t.name == "yauth_webauthn_credentials")
        );
    }

    // A schema diffed against itself produces no changes.
    #[test]
    fn diff_no_changes() {
        let schema = collect_schema(vec![core_schema()]).unwrap();
        let changes = schema_diff(&schema, &schema);
        assert!(changes.is_empty());
    }

    // Postgres: up creates the MFA tables, down drops them (with CASCADE).
    #[test]
    fn diff_add_mfa_produces_valid_postgres_sql() {
        let from = collect_schema(vec![core_schema()]).unwrap();
        let to = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&from, &to);
        let (up, down) = render_changes_sql(&changes, crate::Dialect::Postgres);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS yauth_totp_secrets"));
        assert!(up.contains("CREATE TABLE IF NOT EXISTS yauth_backup_codes"));
        assert!(down.contains("DROP TABLE IF EXISTS yauth_totp_secrets CASCADE"));
        assert!(down.contains("DROP TABLE IF EXISTS yauth_backup_codes CASCADE"));
    }

    // SQLite: up creates the table and contains no PRAGMA lines
    // (generate_single_table_sqlite strips them).
    #[test]
    fn diff_add_mfa_produces_valid_sqlite_sql() {
        let from = collect_schema(vec![core_schema()]).unwrap();
        let to = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&from, &to);
        let (up, _down) = render_changes_sql(&changes, crate::Dialect::Sqlite);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS yauth_totp_secrets"));
        assert!(!up.contains("PRAGMA"));
    }

    // MySQL: up uses backtick-quoted table names and the InnoDB engine.
    #[test]
    fn diff_add_mfa_produces_valid_mysql_sql() {
        let from = collect_schema(vec![core_schema()]).unwrap();
        let to = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&from, &to);
        let (up, _down) = render_changes_sql(&changes, crate::Dialect::Mysql);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS `yauth_totp_secrets`"));
        assert!(up.contains("ENGINE=InnoDB"));
    }

    // Swapping passkey for MFA: both MFA tables are created and the passkey table dropped.
    #[test]
    fn diff_complex_add_and_remove() {
        let from = collect_schema(vec![
            core_schema(),
            plugin_schemas::email_password_schema(),
            plugin_schemas::passkey_schema(),
        ])
        .unwrap();
        let to = collect_schema(vec![
            core_schema(),
            plugin_schemas::email_password_schema(),
            plugin_schemas::mfa_schema(),
        ])
        .unwrap();

        let changes = schema_diff(&from, &to);

        let creates: Vec<_> = changes
            .iter()
            .filter(|c| matches!(c, SchemaChange::CreateTable(_)))
            .collect();
        let drops: Vec<_> = changes
            .iter()
            .filter(|c| matches!(c, SchemaChange::DropTable(_)))
            .collect();

        // yauth_totp_secrets + yauth_backup_codes
        assert_eq!(creates.len(), 2);
        // yauth_webauthn_credentials
        assert_eq!(drops.len(), 1);
    }

    // An appended line shows up with a '+' prefix.
    #[test]
    fn format_diff_shows_additions() {
        let old = "line1\nline2\n";
        let new = "line1\nline2\nline3\n";
        let diff = format_sql_diff(old, new);
        assert!(diff.contains("+line3"));
    }
}