1use std::collections::HashMap;
23
24use crate::ast::*;
25
26#[derive(Debug, Clone, PartialEq)]
28pub enum ChangeAction {
29 Remove(AstNode),
31 Insert(AstNode),
33 Keep(AstNode, AstNode),
35 Move(AstNode, AstNode),
37 Update(AstNode, AstNode),
39}
40
41#[derive(Debug, Clone, PartialEq)]
44pub enum AstNode {
45 Statement(Box<Statement>),
46 Expr(Expr),
47 SelectItem(SelectItem),
48 JoinClause(JoinClause),
49 OrderByItem(OrderByItem),
50 Cte(Box<Cte>),
51 ColumnDef(ColumnDef),
52 TableConstraint(TableConstraint),
53}
54
55impl std::fmt::Display for AstNode {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 AstNode::Statement(s) => write!(f, "{s:?}"),
59 AstNode::Expr(e) => write!(f, "{e:?}"),
60 AstNode::SelectItem(si) => write!(f, "{si:?}"),
61 AstNode::JoinClause(j) => write!(f, "{j:?}"),
62 AstNode::OrderByItem(o) => write!(f, "{o:?}"),
63 AstNode::Cte(c) => write!(f, "{c:?}"),
64 AstNode::ColumnDef(cd) => write!(f, "{cd:?}"),
65 AstNode::TableConstraint(tc) => write!(f, "{tc:?}"),
66 }
67 }
68}
69
70#[must_use]
75pub fn diff(source: &Statement, target: &Statement) -> Vec<ChangeAction> {
76 let mut differ = AstDiffer::new();
77 differ.diff_statements(source, target);
78 differ.changes
79}
80
81struct AstDiffer {
83 changes: Vec<ChangeAction>,
84}
85
86impl AstDiffer {
87 fn new() -> Self {
88 Self {
89 changes: Vec::new(),
90 }
91 }
92
93 fn diff_statements(&mut self, source: &Statement, target: &Statement) {
94 use Statement::*;
95
96 match (source, target) {
97 (Select(s), Select(t)) => self.diff_select(s, t),
98 (Insert(s), Insert(t)) => self.diff_insert(s, t),
99 (Update(s), Update(t)) => self.diff_update(s, t),
100 (Delete(s), Delete(t)) => self.diff_delete(s, t),
101 (CreateTable(s), CreateTable(t)) => self.diff_create_table(s, t),
102 (DropTable(s), DropTable(t)) => self.diff_drop_table(s, t),
103 (SetOperation(s), SetOperation(t)) => self.diff_set_operation(s, t),
104 (AlterTable(s), AlterTable(t)) => self.diff_alter_table(s, t),
105 (CreateView(s), CreateView(t)) => self.diff_create_view(s, t),
106 (Expression(s), Expression(t)) => self.diff_exprs(s, t),
107 _ => {
108 self.changes
110 .push(ChangeAction::Remove(AstNode::Statement(Box::new(
111 source.clone(),
112 ))));
113 self.changes
114 .push(ChangeAction::Insert(AstNode::Statement(Box::new(
115 target.clone(),
116 ))));
117 }
118 }
119 }
120
121 fn diff_select(&mut self, source: &SelectStatement, target: &SelectStatement) {
124 self.diff_ctes(&source.ctes, &target.ctes);
126
127 if source.distinct != target.distinct {
129 if target.distinct {
130 self.changes
131 .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
132 table: None,
133 name: "DISTINCT".to_string(),
134 quote_style: QuoteStyle::None,
135 table_quote_style: QuoteStyle::None,
136 })));
137 } else {
138 self.changes
139 .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
140 table: None,
141 name: "DISTINCT".to_string(),
142 quote_style: QuoteStyle::None,
143 table_quote_style: QuoteStyle::None,
144 })));
145 }
146 }
147
148 self.diff_select_items(&source.columns, &target.columns);
150
151 match (&source.from, &target.from) {
153 (Some(sf), Some(tf)) => self.diff_table_sources(&sf.source, &tf.source),
154 (None, Some(tf)) => self.insert_table_source(&tf.source),
155 (Some(sf), None) => self.remove_table_source(&sf.source),
156 (None, None) => {}
157 }
158
159 self.diff_joins(&source.joins, &target.joins);
161
162 self.diff_optional_exprs(&source.where_clause, &target.where_clause);
164
165 self.diff_expr_lists(&source.group_by, &target.group_by);
167
168 self.diff_optional_exprs(&source.having, &target.having);
170
171 self.diff_order_by(&source.order_by, &target.order_by);
173
174 self.diff_optional_exprs(&source.limit, &target.limit);
176
177 self.diff_optional_exprs(&source.offset, &target.offset);
179
180 self.diff_optional_exprs(&source.qualify, &target.qualify);
182 }
183
184 fn diff_insert(&mut self, source: &InsertStatement, target: &InsertStatement) {
187 if source.table != target.table {
188 self.changes.push(ChangeAction::Update(
189 AstNode::Expr(table_ref_to_expr(&source.table)),
190 AstNode::Expr(table_ref_to_expr(&target.table)),
191 ));
192 }
193
194 self.diff_string_lists(&source.columns, &target.columns);
196
197 match (&source.source, &target.source) {
199 (InsertSource::Values(sv), InsertSource::Values(tv)) => {
200 for (i, (sr, tr)) in sv.iter().zip(tv.iter()).enumerate() {
201 self.diff_expr_lists(sr, tr);
202 let _ = i;
203 }
204 for extra in sv.iter().skip(tv.len()) {
205 for e in extra {
206 self.changes
207 .push(ChangeAction::Remove(AstNode::Expr(e.clone())));
208 }
209 }
210 for extra in tv.iter().skip(sv.len()) {
211 for e in extra {
212 self.changes
213 .push(ChangeAction::Insert(AstNode::Expr(e.clone())));
214 }
215 }
216 }
217 (InsertSource::Query(sq), InsertSource::Query(tq)) => {
218 self.diff_statements(sq, tq);
219 }
220 _ => {
221 self.changes
222 .push(ChangeAction::Remove(AstNode::Statement(Box::new(
223 Statement::Insert(source.clone()),
224 ))));
225 self.changes
226 .push(ChangeAction::Insert(AstNode::Statement(Box::new(
227 Statement::Insert(target.clone()),
228 ))));
229 }
230 }
231 }
232
233 fn diff_update(&mut self, source: &UpdateStatement, target: &UpdateStatement) {
236 if source.table != target.table {
237 self.changes.push(ChangeAction::Update(
238 AstNode::Expr(table_ref_to_expr(&source.table)),
239 AstNode::Expr(table_ref_to_expr(&target.table)),
240 ));
241 }
242
243 let source_map: HashMap<&str, &Expr> = source
245 .assignments
246 .iter()
247 .map(|(k, v)| (k.as_str(), v))
248 .collect();
249 let target_map: HashMap<&str, &Expr> = target
250 .assignments
251 .iter()
252 .map(|(k, v)| (k.as_str(), v))
253 .collect();
254
255 for (col, src_val) in &source_map {
256 if let Some(tgt_val) = target_map.get(col) {
257 self.diff_exprs(src_val, tgt_val);
258 } else {
259 self.changes
260 .push(ChangeAction::Remove(AstNode::Expr((*src_val).clone())));
261 }
262 }
263 for (col, tgt_val) in &target_map {
264 if !source_map.contains_key(col) {
265 self.changes
266 .push(ChangeAction::Insert(AstNode::Expr((*tgt_val).clone())));
267 }
268 }
269
270 self.diff_optional_exprs(&source.where_clause, &target.where_clause);
271 }
272
273 fn diff_delete(&mut self, source: &DeleteStatement, target: &DeleteStatement) {
276 if source.table != target.table {
277 self.changes.push(ChangeAction::Update(
278 AstNode::Expr(table_ref_to_expr(&source.table)),
279 AstNode::Expr(table_ref_to_expr(&target.table)),
280 ));
281 }
282 self.diff_optional_exprs(&source.where_clause, &target.where_clause);
283 }
284
285 fn diff_create_table(&mut self, source: &CreateTableStatement, target: &CreateTableStatement) {
288 if source.table != target.table {
289 self.changes.push(ChangeAction::Update(
290 AstNode::Expr(table_ref_to_expr(&source.table)),
291 AstNode::Expr(table_ref_to_expr(&target.table)),
292 ));
293 }
294
295 let source_cols: HashMap<&str, &ColumnDef> = source
297 .columns
298 .iter()
299 .map(|c| (c.name.as_str(), c))
300 .collect();
301 let target_cols: HashMap<&str, &ColumnDef> = target
302 .columns
303 .iter()
304 .map(|c| (c.name.as_str(), c))
305 .collect();
306
307 for (name, src_col) in &source_cols {
308 if let Some(tgt_col) = target_cols.get(name) {
309 if src_col != tgt_col {
310 self.changes.push(ChangeAction::Update(
311 AstNode::ColumnDef((*src_col).clone()),
312 AstNode::ColumnDef((*tgt_col).clone()),
313 ));
314 } else {
315 self.changes.push(ChangeAction::Keep(
316 AstNode::ColumnDef((*src_col).clone()),
317 AstNode::ColumnDef((*tgt_col).clone()),
318 ));
319 }
320 } else {
321 self.changes
322 .push(ChangeAction::Remove(AstNode::ColumnDef((*src_col).clone())));
323 }
324 }
325 for (name, tgt_col) in &target_cols {
326 if !source_cols.contains_key(name) {
327 self.changes
328 .push(ChangeAction::Insert(AstNode::ColumnDef((*tgt_col).clone())));
329 }
330 }
331
332 self.diff_constraints(&source.constraints, &target.constraints);
334 }
335
336 fn diff_drop_table(&mut self, source: &DropTableStatement, target: &DropTableStatement) {
339 if source != target {
340 self.changes.push(ChangeAction::Update(
341 AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
342 AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
343 ));
344 } else {
345 self.changes.push(ChangeAction::Keep(
346 AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
347 AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
348 ));
349 }
350 }
351
352 fn diff_set_operation(
355 &mut self,
356 source: &SetOperationStatement,
357 target: &SetOperationStatement,
358 ) {
359 if source.op != target.op || source.all != target.all {
360 self.changes.push(ChangeAction::Update(
361 AstNode::Statement(Box::new(Statement::SetOperation(source.clone()))),
362 AstNode::Statement(Box::new(Statement::SetOperation(target.clone()))),
363 ));
364 return;
365 }
366 self.diff_statements(&source.left, &target.left);
367 self.diff_statements(&source.right, &target.right);
368 self.diff_order_by(&source.order_by, &target.order_by);
369 self.diff_optional_exprs(&source.limit, &target.limit);
370 self.diff_optional_exprs(&source.offset, &target.offset);
371 }
372
373 fn diff_alter_table(&mut self, source: &AlterTableStatement, target: &AlterTableStatement) {
376 if source.table != target.table {
377 self.changes.push(ChangeAction::Update(
378 AstNode::Expr(table_ref_to_expr(&source.table)),
379 AstNode::Expr(table_ref_to_expr(&target.table)),
380 ));
381 }
382 if source.actions != target.actions {
384 self.changes.push(ChangeAction::Update(
385 AstNode::Statement(Box::new(Statement::AlterTable(source.clone()))),
386 AstNode::Statement(Box::new(Statement::AlterTable(target.clone()))),
387 ));
388 }
389 }
390
391 fn diff_create_view(&mut self, source: &CreateViewStatement, target: &CreateViewStatement) {
394 if source.name != target.name {
395 self.changes.push(ChangeAction::Update(
396 AstNode::Expr(table_ref_to_expr(&source.name)),
397 AstNode::Expr(table_ref_to_expr(&target.name)),
398 ));
399 }
400 self.diff_statements(&source.query, &target.query);
401 }
402
403 fn diff_exprs(&mut self, source: &Expr, target: &Expr) {
406 if source == target {
407 self.changes.push(ChangeAction::Keep(
408 AstNode::Expr(source.clone()),
409 AstNode::Expr(target.clone()),
410 ));
411 return;
412 }
413
414 match (source, target) {
416 (
417 Expr::BinaryOp {
418 left: sl,
419 op: sop,
420 right: sr,
421 },
422 Expr::BinaryOp {
423 left: tl,
424 op: top,
425 right: tr,
426 },
427 ) => {
428 if sop != top {
429 self.changes.push(ChangeAction::Update(
430 AstNode::Expr(source.clone()),
431 AstNode::Expr(target.clone()),
432 ));
433 } else {
434 self.diff_exprs(sl, tl);
435 self.diff_exprs(sr, tr);
436 }
437 }
438 (Expr::UnaryOp { op: sop, expr: se }, Expr::UnaryOp { op: top, expr: te }) => {
439 if sop != top {
440 self.changes.push(ChangeAction::Update(
441 AstNode::Expr(source.clone()),
442 AstNode::Expr(target.clone()),
443 ));
444 } else {
445 self.diff_exprs(se, te);
446 }
447 }
448 (
449 Expr::Function {
450 name: sn,
451 args: sa,
452 distinct: sd,
453 ..
454 },
455 Expr::Function {
456 name: tn,
457 args: ta,
458 distinct: td,
459 ..
460 },
461 ) => {
462 if sn != tn || sd != td {
463 self.changes.push(ChangeAction::Update(
464 AstNode::Expr(source.clone()),
465 AstNode::Expr(target.clone()),
466 ));
467 } else {
468 self.diff_expr_lists(sa, ta);
469 }
470 }
471 (
472 Expr::Cast {
473 expr: se,
474 data_type: sd,
475 },
476 Expr::Cast {
477 expr: te,
478 data_type: td,
479 },
480 ) => {
481 if sd != td {
482 self.changes.push(ChangeAction::Update(
483 AstNode::Expr(source.clone()),
484 AstNode::Expr(target.clone()),
485 ));
486 } else {
487 self.diff_exprs(se, te);
488 }
489 }
490 (
491 Expr::Case {
492 operand: so,
493 when_clauses: sw,
494 else_clause: se,
495 },
496 Expr::Case {
497 operand: to,
498 when_clauses: tw,
499 else_clause: te,
500 },
501 ) => {
502 self.diff_optional_boxed_exprs(so, to);
503 for (i, ((sc, sr), (tc, tr))) in sw.iter().zip(tw.iter()).enumerate() {
505 self.diff_exprs(sc, tc);
506 self.diff_exprs(sr, tr);
507 let _ = i;
508 }
509 for (sc, sr) in sw.iter().skip(tw.len()) {
510 self.changes
511 .push(ChangeAction::Remove(AstNode::Expr(sc.clone())));
512 self.changes
513 .push(ChangeAction::Remove(AstNode::Expr(sr.clone())));
514 }
515 for (tc, tr) in tw.iter().skip(sw.len()) {
516 self.changes
517 .push(ChangeAction::Insert(AstNode::Expr(tc.clone())));
518 self.changes
519 .push(ChangeAction::Insert(AstNode::Expr(tr.clone())));
520 }
521 self.diff_optional_boxed_exprs(se, te);
522 }
523 (Expr::Nested(se), Expr::Nested(te)) => self.diff_exprs(se, te),
524 (
525 Expr::Between {
526 expr: se,
527 low: sl,
528 high: sh,
529 negated: sn,
530 },
531 Expr::Between {
532 expr: te,
533 low: tl,
534 high: th,
535 negated: tn,
536 },
537 ) => {
538 if sn != tn {
539 self.changes.push(ChangeAction::Update(
540 AstNode::Expr(source.clone()),
541 AstNode::Expr(target.clone()),
542 ));
543 } else {
544 self.diff_exprs(se, te);
545 self.diff_exprs(sl, tl);
546 self.diff_exprs(sh, th);
547 }
548 }
549 (
550 Expr::InList {
551 expr: se,
552 list: sl,
553 negated: sn,
554 },
555 Expr::InList {
556 expr: te,
557 list: tl,
558 negated: tn,
559 },
560 ) => {
561 if sn != tn {
562 self.changes.push(ChangeAction::Update(
563 AstNode::Expr(source.clone()),
564 AstNode::Expr(target.clone()),
565 ));
566 } else {
567 self.diff_exprs(se, te);
568 self.diff_expr_lists(sl, tl);
569 }
570 }
571 (
572 Expr::InSubquery {
573 expr: se,
574 subquery: ss,
575 negated: sn,
576 },
577 Expr::InSubquery {
578 expr: te,
579 subquery: ts,
580 negated: tn,
581 },
582 ) => {
583 if sn != tn {
584 self.changes.push(ChangeAction::Update(
585 AstNode::Expr(source.clone()),
586 AstNode::Expr(target.clone()),
587 ));
588 } else {
589 self.diff_exprs(se, te);
590 self.diff_statements(ss, ts);
591 }
592 }
593 (
594 Expr::IsNull {
595 expr: se,
596 negated: sn,
597 },
598 Expr::IsNull {
599 expr: te,
600 negated: tn,
601 },
602 ) => {
603 if sn != tn {
604 self.changes.push(ChangeAction::Update(
605 AstNode::Expr(source.clone()),
606 AstNode::Expr(target.clone()),
607 ));
608 } else {
609 self.diff_exprs(se, te);
610 }
611 }
612 (
613 Expr::Like {
614 expr: se,
615 pattern: sp,
616 negated: sn,
617 ..
618 },
619 Expr::Like {
620 expr: te,
621 pattern: tp,
622 negated: tn,
623 ..
624 },
625 )
626 | (
627 Expr::ILike {
628 expr: se,
629 pattern: sp,
630 negated: sn,
631 ..
632 },
633 Expr::ILike {
634 expr: te,
635 pattern: tp,
636 negated: tn,
637 ..
638 },
639 ) => {
640 if sn != tn {
641 self.changes.push(ChangeAction::Update(
642 AstNode::Expr(source.clone()),
643 AstNode::Expr(target.clone()),
644 ));
645 } else {
646 self.diff_exprs(se, te);
647 self.diff_exprs(sp, tp);
648 }
649 }
650 (Expr::Subquery(ss), Expr::Subquery(ts)) => self.diff_statements(ss, ts),
651 (
652 Expr::Exists {
653 subquery: ss,
654 negated: sn,
655 },
656 Expr::Exists {
657 subquery: ts,
658 negated: tn,
659 },
660 ) => {
661 if sn != tn {
662 self.changes.push(ChangeAction::Update(
663 AstNode::Expr(source.clone()),
664 AstNode::Expr(target.clone()),
665 ));
666 } else {
667 self.diff_statements(ss, ts);
668 }
669 }
670 (Expr::Alias { expr: se, name: sn }, Expr::Alias { expr: te, name: tn }) => {
671 if sn != tn {
672 self.changes.push(ChangeAction::Update(
673 AstNode::Expr(source.clone()),
674 AstNode::Expr(target.clone()),
675 ));
676 } else {
677 self.diff_exprs(se, te);
678 }
679 }
680 (Expr::Coalesce(sa), Expr::Coalesce(ta)) => self.diff_expr_lists(sa, ta),
681 (Expr::ArrayLiteral(sa), Expr::ArrayLiteral(ta)) => self.diff_expr_lists(sa, ta),
682 (Expr::Tuple(sa), Expr::Tuple(ta)) => self.diff_expr_lists(sa, ta),
683 (Expr::TypedFunction { func: sf, .. }, Expr::TypedFunction { func: tf, .. }) => {
684 if std::mem::discriminant(sf) == std::mem::discriminant(tf) && source == target {
685 self.changes.push(ChangeAction::Keep(
686 AstNode::Expr(source.clone()),
687 AstNode::Expr(target.clone()),
688 ));
689 } else {
690 self.changes.push(ChangeAction::Update(
691 AstNode::Expr(source.clone()),
692 AstNode::Expr(target.clone()),
693 ));
694 }
695 }
696 _ => {
698 self.changes.push(ChangeAction::Update(
699 AstNode::Expr(source.clone()),
700 AstNode::Expr(target.clone()),
701 ));
702 }
703 }
704 }
705
706 fn diff_expr_lists(&mut self, source: &[Expr], target: &[Expr]) {
708 let lcs = compute_lcs(source, target);
710 let mut si = 0;
711 let mut ti = 0;
712 let mut li = 0;
713
714 while si < source.len() || ti < target.len() {
715 if li < lcs.len() {
716 let (lcs_si, lcs_ti) = lcs[li];
717
718 while si < lcs_si {
720 self.changes
721 .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
722 si += 1;
723 }
724 while ti < lcs_ti {
726 self.changes
727 .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
728 ti += 1;
729 }
730 self.diff_exprs(&source[si], &target[ti]);
732 si += 1;
733 ti += 1;
734 li += 1;
735 } else {
736 while si < source.len() {
738 self.changes
739 .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
740 si += 1;
741 }
742 while ti < target.len() {
744 self.changes
745 .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
746 ti += 1;
747 }
748 }
749 }
750 }
751
752 fn diff_select_items(&mut self, source: &[SelectItem], target: &[SelectItem]) {
753 let min_len = source.len().min(target.len());
754 for i in 0..min_len {
755 if source[i] == target[i] {
756 self.changes.push(ChangeAction::Keep(
757 AstNode::SelectItem(source[i].clone()),
758 AstNode::SelectItem(target[i].clone()),
759 ));
760 } else {
761 match (&source[i], &target[i]) {
762 (
763 SelectItem::Expr {
764 expr: se,
765 alias: sa,
766 },
767 SelectItem::Expr {
768 expr: te,
769 alias: ta,
770 },
771 ) => {
772 if sa != ta {
773 self.changes.push(ChangeAction::Update(
774 AstNode::SelectItem(source[i].clone()),
775 AstNode::SelectItem(target[i].clone()),
776 ));
777 } else {
778 self.diff_exprs(se, te);
779 }
780 }
781 _ => {
782 self.changes.push(ChangeAction::Update(
783 AstNode::SelectItem(source[i].clone()),
784 AstNode::SelectItem(target[i].clone()),
785 ));
786 }
787 }
788 }
789 }
790 for item in source.iter().skip(min_len) {
791 self.changes
792 .push(ChangeAction::Remove(AstNode::SelectItem(item.clone())));
793 }
794 for item in target.iter().skip(min_len) {
795 self.changes
796 .push(ChangeAction::Insert(AstNode::SelectItem(item.clone())));
797 }
798 }
799
800 fn diff_optional_exprs(&mut self, source: &Option<Expr>, target: &Option<Expr>) {
801 match (source, target) {
802 (Some(s), Some(t)) => self.diff_exprs(s, t),
803 (None, Some(t)) => self
804 .changes
805 .push(ChangeAction::Insert(AstNode::Expr(t.clone()))),
806 (Some(s), None) => self
807 .changes
808 .push(ChangeAction::Remove(AstNode::Expr(s.clone()))),
809 (None, None) => {}
810 }
811 }
812
813 fn diff_optional_boxed_exprs(
814 &mut self,
815 source: &Option<Box<Expr>>,
816 target: &Option<Box<Expr>>,
817 ) {
818 match (source, target) {
819 (Some(s), Some(t)) => self.diff_exprs(s, t),
820 (None, Some(t)) => self
821 .changes
822 .push(ChangeAction::Insert(AstNode::Expr((**t).clone()))),
823 (Some(s), None) => self
824 .changes
825 .push(ChangeAction::Remove(AstNode::Expr((**s).clone()))),
826 (None, None) => {}
827 }
828 }
829
830 fn diff_ctes(&mut self, source: &[Cte], target: &[Cte]) {
831 let source_map: HashMap<&str, &Cte> = source.iter().map(|c| (c.name.as_str(), c)).collect();
833 let target_map: HashMap<&str, &Cte> = target.iter().map(|c| (c.name.as_str(), c)).collect();
834
835 for (name, sc) in &source_map {
836 if let Some(tc) = target_map.get(name) {
837 if sc == tc {
838 self.changes.push(ChangeAction::Keep(
839 AstNode::Cte(Box::new((*sc).clone())),
840 AstNode::Cte(Box::new((*tc).clone())),
841 ));
842 } else {
843 self.diff_statements(&sc.query, &tc.query);
844 }
845 } else {
846 self.changes
847 .push(ChangeAction::Remove(AstNode::Cte(Box::new((*sc).clone()))));
848 }
849 }
850 for (name, tc) in &target_map {
851 if !source_map.contains_key(name) {
852 self.changes
853 .push(ChangeAction::Insert(AstNode::Cte(Box::new((*tc).clone()))));
854 }
855 }
856 }
857
858 fn diff_joins(&mut self, source: &[JoinClause], target: &[JoinClause]) {
859 let min_len = source.len().min(target.len());
860 for i in 0..min_len {
861 if source[i] == target[i] {
862 self.changes.push(ChangeAction::Keep(
863 AstNode::JoinClause(source[i].clone()),
864 AstNode::JoinClause(target[i].clone()),
865 ));
866 } else if source[i].join_type == target[i].join_type {
867 self.diff_table_sources(&source[i].table, &target[i].table);
869 self.diff_optional_exprs(&source[i].on, &target[i].on);
870 } else {
871 self.changes.push(ChangeAction::Update(
872 AstNode::JoinClause(source[i].clone()),
873 AstNode::JoinClause(target[i].clone()),
874 ));
875 }
876 }
877 for item in source.iter().skip(min_len) {
878 self.changes
879 .push(ChangeAction::Remove(AstNode::JoinClause(item.clone())));
880 }
881 for item in target.iter().skip(min_len) {
882 self.changes
883 .push(ChangeAction::Insert(AstNode::JoinClause(item.clone())));
884 }
885 }
886
887 fn diff_order_by(&mut self, source: &[OrderByItem], target: &[OrderByItem]) {
888 let min_len = source.len().min(target.len());
889 for i in 0..min_len {
890 if source[i] == target[i] {
891 self.changes.push(ChangeAction::Keep(
892 AstNode::OrderByItem(source[i].clone()),
893 AstNode::OrderByItem(target[i].clone()),
894 ));
895 } else if source[i].ascending == target[i].ascending
896 && source[i].nulls_first == target[i].nulls_first
897 {
898 self.diff_exprs(&source[i].expr, &target[i].expr);
899 } else {
900 self.changes.push(ChangeAction::Update(
901 AstNode::OrderByItem(source[i].clone()),
902 AstNode::OrderByItem(target[i].clone()),
903 ));
904 }
905 }
906 for item in source.iter().skip(min_len) {
907 self.changes
908 .push(ChangeAction::Remove(AstNode::OrderByItem(item.clone())));
909 }
910 for item in target.iter().skip(min_len) {
911 self.changes
912 .push(ChangeAction::Insert(AstNode::OrderByItem(item.clone())));
913 }
914 }
915
916 fn diff_table_sources(&mut self, source: &TableSource, target: &TableSource) {
917 if source == target {
918 return;
919 }
920 match (source, target) {
921 (TableSource::Table(st), TableSource::Table(tt)) => {
922 if st != tt {
923 self.changes.push(ChangeAction::Update(
924 AstNode::Expr(table_ref_to_expr(st)),
925 AstNode::Expr(table_ref_to_expr(tt)),
926 ));
927 }
928 }
929 (TableSource::Subquery { query: sq, .. }, TableSource::Subquery { query: tq, .. }) => {
930 self.diff_statements(sq, tq);
931 }
932 _ => {
933 self.remove_table_source(source);
935 self.insert_table_source(target);
936 }
937 }
938 }
939
940 fn insert_table_source(&mut self, source: &TableSource) {
941 match source {
942 TableSource::Table(t) => {
943 self.changes
944 .push(ChangeAction::Insert(AstNode::Expr(table_ref_to_expr(t))));
945 }
946 TableSource::Subquery { query, .. } => {
947 self.changes
948 .push(ChangeAction::Insert(AstNode::Statement(Box::new(
949 (**query).clone(),
950 ))));
951 }
952 other => {
953 self.changes
954 .push(ChangeAction::Insert(AstNode::Expr(Expr::StringLiteral(
955 format!("{other:?}"),
956 ))));
957 }
958 }
959 }
960
961 fn remove_table_source(&mut self, source: &TableSource) {
962 match source {
963 TableSource::Table(t) => {
964 self.changes
965 .push(ChangeAction::Remove(AstNode::Expr(table_ref_to_expr(t))));
966 }
967 TableSource::Subquery { query, .. } => {
968 self.changes
969 .push(ChangeAction::Remove(AstNode::Statement(Box::new(
970 (**query).clone(),
971 ))));
972 }
973 other => {
974 self.changes
975 .push(ChangeAction::Remove(AstNode::Expr(Expr::StringLiteral(
976 format!("{other:?}"),
977 ))));
978 }
979 }
980 }
981
982 fn diff_constraints(&mut self, source: &[TableConstraint], target: &[TableConstraint]) {
983 let min_len = source.len().min(target.len());
985 for i in 0..min_len {
986 if source[i] == target[i] {
987 self.changes.push(ChangeAction::Keep(
988 AstNode::TableConstraint(source[i].clone()),
989 AstNode::TableConstraint(target[i].clone()),
990 ));
991 } else {
992 self.changes.push(ChangeAction::Update(
993 AstNode::TableConstraint(source[i].clone()),
994 AstNode::TableConstraint(target[i].clone()),
995 ));
996 }
997 }
998 for item in source.iter().skip(min_len) {
999 self.changes
1000 .push(ChangeAction::Remove(AstNode::TableConstraint(item.clone())));
1001 }
1002 for item in target.iter().skip(min_len) {
1003 self.changes
1004 .push(ChangeAction::Insert(AstNode::TableConstraint(item.clone())));
1005 }
1006 }
1007
1008 fn diff_string_lists(&mut self, source: &[String], target: &[String]) {
1009 for s in source {
1010 if !target.contains(s) {
1011 self.changes
1012 .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
1013 table: None,
1014 name: s.clone(),
1015 quote_style: QuoteStyle::None,
1016 table_quote_style: QuoteStyle::None,
1017 })));
1018 }
1019 }
1020 for t in target {
1021 if !source.contains(t) {
1022 self.changes
1023 .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
1024 table: None,
1025 name: t.clone(),
1026 quote_style: QuoteStyle::None,
1027 table_quote_style: QuoteStyle::None,
1028 })));
1029 }
1030 }
1031 }
1032}
1033
1034fn compute_lcs(source: &[Expr], target: &[Expr]) -> Vec<(usize, usize)> {
1041 let m = source.len();
1042 let n = target.len();
1043 if m == 0 || n == 0 {
1044 return Vec::new();
1045 }
1046
1047 let mut dp = vec![vec![0u32; n + 1]; m + 1];
1049 for i in 1..=m {
1050 for j in 1..=n {
1051 if source[i - 1] == target[j - 1] {
1052 dp[i][j] = dp[i - 1][j - 1] + 1;
1053 } else {
1054 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
1055 }
1056 }
1057 }
1058
1059 let mut result = Vec::new();
1061 let mut i = m;
1062 let mut j = n;
1063 while i > 0 && j > 0 {
1064 if source[i - 1] == target[j - 1] {
1065 result.push((i - 1, j - 1));
1066 i -= 1;
1067 j -= 1;
1068 } else if dp[i - 1][j] >= dp[i][j - 1] {
1069 i -= 1;
1070 } else {
1071 j -= 1;
1072 }
1073 }
1074 result.reverse();
1075 result
1076}
1077
1078fn table_ref_to_expr(table: &TableRef) -> Expr {
1080 let full_name = match (&table.catalog, &table.schema) {
1081 (Some(c), Some(s)) => format!("{c}.{s}.{}", table.name),
1082 (None, Some(s)) => format!("{s}.{}", table.name),
1083 _ => table.name.clone(),
1084 };
1085 Expr::Column {
1086 table: table.schema.clone(),
1087 name: full_name,
1088 quote_style: table.name_quote_style,
1089 table_quote_style: QuoteStyle::None,
1090 }
1091}
1092
1093pub fn diff_sql(
1104 source_sql: &str,
1105 target_sql: &str,
1106 dialect: crate::dialects::Dialect,
1107) -> crate::errors::Result<Vec<ChangeAction>> {
1108 let source = crate::parser::parse(source_sql, dialect)?;
1109 let target = crate::parser::parse(target_sql, dialect)?;
1110 Ok(diff(&source, &target))
1111}
1112
1113#[cfg(test)]
1114mod tests {
1115 use super::*;
1116 use crate::dialects::Dialect;
1117 use crate::parser::parse;
1118
1119 fn count_by_action(changes: &[ChangeAction]) -> (usize, usize, usize, usize, usize) {
1120 let mut keeps = 0;
1121 let mut inserts = 0;
1122 let mut removes = 0;
1123 let mut updates = 0;
1124 let mut moves = 0;
1125 for c in changes {
1126 match c {
1127 ChangeAction::Keep(..) => keeps += 1,
1128 ChangeAction::Insert(..) => inserts += 1,
1129 ChangeAction::Remove(..) => removes += 1,
1130 ChangeAction::Update(..) => updates += 1,
1131 ChangeAction::Move(..) => moves += 1,
1132 }
1133 }
1134 (keeps, inserts, removes, updates, moves)
1135 }
1136
1137 #[test]
1138 fn test_identical_queries_are_all_keep() {
1139 let sql = "SELECT a, b FROM t WHERE a > 1";
1140 let source = parse(sql, Dialect::Ansi).unwrap();
1141 let target = parse(sql, Dialect::Ansi).unwrap();
1142 let changes = diff(&source, &target);
1143 let (keeps, inserts, removes, updates, _moves) = count_by_action(&changes);
1144 assert!(keeps > 0, "should have keep actions");
1145 assert_eq!(inserts, 0, "no inserts for identical queries");
1146 assert_eq!(removes, 0, "no removes for identical queries");
1147 assert_eq!(updates, 0, "no updates for identical queries");
1148 }
1149
1150 #[test]
1151 fn test_column_added() {
1152 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1153 let target = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1154 let changes = diff(&source, &target);
1155 let (keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1156 assert!(keeps > 0);
1157 assert!(inserts > 0, "should have insert for new column b");
1158 assert_eq!(removes, 0);
1159 }
1160
1161 #[test]
1162 fn test_column_removed() {
1163 let source = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1164 let target = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1165 let changes = diff(&source, &target);
1166 let (keeps, _inserts, removes, _updates, _moves) = count_by_action(&changes);
1167 assert!(keeps > 0);
1168 assert!(removes > 0, "should have remove for column b");
1169 }
1170
1171 #[test]
1172 fn test_column_changed() {
1173 let source = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1174 let target = parse("SELECT a, c FROM t", Dialect::Ansi).unwrap();
1175 let changes = diff(&source, &target);
1176 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1177 assert!(updates > 0, "should have update for b -> c");
1178 }
1179
1180 #[test]
1181 fn test_where_clause_added() {
1182 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1183 let target = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1184 let changes = diff(&source, &target);
1185 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1186 assert!(inserts > 0, "should have insert for WHERE clause");
1187 }
1188
1189 #[test]
1190 fn test_where_clause_removed() {
1191 let source = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1192 let target = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1193 let changes = diff(&source, &target);
1194 let (_keeps, _inserts, removes, _updates, _moves) = count_by_action(&changes);
1195 assert!(removes > 0, "should have remove for WHERE clause");
1196 }
1197
1198 #[test]
1199 fn test_where_clause_updated() {
1200 let source = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1201 let target = parse("SELECT a FROM t WHERE a > 2", Dialect::Ansi).unwrap();
1202 let changes = diff(&source, &target);
1203 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1204 assert!(updates > 0, "should have update for WHERE value change");
1205 }
1206
1207 #[test]
1208 fn test_table_changed() {
1209 let source = parse("SELECT a FROM t1", Dialect::Ansi).unwrap();
1210 let target = parse("SELECT a FROM t2", Dialect::Ansi).unwrap();
1211 let changes = diff(&source, &target);
1212 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1213 assert!(updates > 0, "should have update for table change");
1214 }
1215
1216 #[test]
1217 fn test_join_added() {
1218 let source = parse("SELECT a FROM t1", Dialect::Ansi).unwrap();
1219 let target = parse("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id", Dialect::Ansi).unwrap();
1220 let changes = diff(&source, &target);
1221 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1222 assert!(inserts > 0, "should have insert for JOIN");
1223 }
1224
1225 #[test]
1226 fn test_order_by_changed() {
1227 let source = parse("SELECT a FROM t ORDER BY a ASC", Dialect::Ansi).unwrap();
1228 let target = parse("SELECT a FROM t ORDER BY a DESC", Dialect::Ansi).unwrap();
1229 let changes = diff(&source, &target);
1230 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1231 assert!(updates > 0, "should have update for ORDER BY direction");
1232 }
1233
1234 #[test]
1235 fn test_complex_nested_query() {
1236 let source = parse(
1237 "SELECT a, b FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 0)",
1238 Dialect::Ansi,
1239 )
1240 .unwrap();
1241 let target = parse(
1242 "SELECT a, c FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 5)",
1243 Dialect::Ansi,
1244 )
1245 .unwrap();
1246 let changes = diff(&source, &target);
1247 let (keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1248 assert!(keeps > 0, "unchanged parts should be kept");
1249 assert!(updates > 0, "changed parts should be updated (b->c, 0->5)");
1250 }
1251
1252 #[test]
1253 fn test_different_statement_types() {
1254 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1255 let target = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
1256 let changes = diff(&source, &target);
1257 let (_keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1258 assert!(removes > 0, "source should be removed");
1259 assert!(inserts > 0, "target should be inserted");
1260 }
1261
1262 #[test]
1263 fn test_cte_added() {
1264 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1265 let target = parse("WITH cte AS (SELECT 1 AS x) SELECT a FROM t", Dialect::Ansi).unwrap();
1266 let changes = diff(&source, &target);
1267 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1268 assert!(inserts > 0, "should have insert for CTE");
1269 }
1270
1271 #[test]
1272 fn test_limit_changed() {
1273 let source = parse("SELECT a FROM t LIMIT 10", Dialect::Ansi).unwrap();
1274 let target = parse("SELECT a FROM t LIMIT 20", Dialect::Ansi).unwrap();
1275 let changes = diff(&source, &target);
1276 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1277 assert!(updates > 0, "should have update for LIMIT change");
1278 }
1279
1280 #[test]
1281 fn test_group_by_added() {
1282 let source = parse("SELECT a, COUNT(*) FROM t", Dialect::Ansi).unwrap();
1283 let target = parse("SELECT a, COUNT(*) FROM t GROUP BY a", Dialect::Ansi).unwrap();
1284 let changes = diff(&source, &target);
1285 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1286 assert!(inserts > 0, "should have insert for GROUP BY");
1287 }
1288
1289 #[test]
1290 fn test_diff_sql_convenience() {
1291 let changes = diff_sql("SELECT a FROM t", "SELECT a, b FROM t", Dialect::Ansi).unwrap();
1292 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1293 assert!(inserts > 0);
1294 }
1295
1296 #[test]
1297 fn test_having_added() {
1298 let source = parse("SELECT a, COUNT(*) FROM t GROUP BY a", Dialect::Ansi).unwrap();
1299 let target = parse(
1300 "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
1301 Dialect::Ansi,
1302 )
1303 .unwrap();
1304 let changes = diff(&source, &target);
1305 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1306 assert!(inserts > 0, "should have insert for HAVING");
1307 }
1308
1309 #[test]
1310 fn test_create_table_column_diff() {
1311 let source = parse("CREATE TABLE t (a INT, b TEXT)", Dialect::Ansi).unwrap();
1312 let target = parse("CREATE TABLE t (a INT, c TEXT)", Dialect::Ansi).unwrap();
1313 let changes = diff(&source, &target);
1314 let (_keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1315 assert!(removes > 0, "should remove column b");
1316 assert!(inserts > 0, "should insert column c");
1317 }
1318
1319 #[test]
1320 fn test_union_diff() {
1321 let source = parse("SELECT a FROM t1 UNION SELECT b FROM t2", Dialect::Ansi).unwrap();
1322 let target = parse("SELECT a FROM t1 UNION SELECT c FROM t2", Dialect::Ansi).unwrap();
1323 let changes = diff(&source, &target);
1324 let (keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1325 assert!(keeps > 0);
1326 assert!(updates > 0, "should have update for b -> c");
1327 }
1328}