Skip to main content

sqlglot_rust/diff/
mod.rs

1//! AST Diff — semantic comparison of SQL expression trees.
2//!
3//! Implements a tree edit distance algorithm inspired by the Change Distiller
4//! approach used in Python sqlglot's `diff.py`. Computes a sequence of
5//! [`ChangeAction`]s that transform one AST into another.
6//!
7//! # Example
8//!
9//! ```rust
10//! use sqlglot_rust::{parse, Dialect};
11//! use sqlglot_rust::diff::{diff, ChangeAction};
12//!
13//! let source = parse("SELECT a, b FROM t WHERE a > 1", Dialect::Ansi).unwrap();
14//! let target = parse("SELECT a, c FROM t WHERE a > 2", Dialect::Ansi).unwrap();
15//! let changes = diff(&source, &target);
16//!
17//! for change in &changes {
18//!     println!("{change:?}");
19//! }
20//! ```
21
22use std::collections::HashMap;
23
24use crate::ast::*;
25
26/// A change action describing a single difference between two ASTs.
27#[derive(Debug, Clone, PartialEq)]
28pub enum ChangeAction {
29    /// A node present in `source` that was removed.
30    Remove(AstNode),
31    /// A node inserted into `target` that was not in `source`.
32    Insert(AstNode),
33    /// A node that is structurally identical in both trees.
34    Keep(AstNode, AstNode),
35    /// A node that was moved to a different position in the tree.
36    Move(AstNode, AstNode),
37    /// A node in `source` that was replaced by a different node in `target`.
38    Update(AstNode, AstNode),
39}
40
41/// A wrapper around an AST node that can represent either statements or
42/// expressions, enabling uniform diff output.
43#[derive(Debug, Clone, PartialEq)]
44pub enum AstNode {
45    Statement(Box<Statement>),
46    Expr(Expr),
47    SelectItem(SelectItem),
48    JoinClause(JoinClause),
49    OrderByItem(OrderByItem),
50    Cte(Box<Cte>),
51    ColumnDef(ColumnDef),
52    TableConstraint(TableConstraint),
53}
54
55impl std::fmt::Display for AstNode {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            AstNode::Statement(s) => write!(f, "{s:?}"),
59            AstNode::Expr(e) => write!(f, "{e:?}"),
60            AstNode::SelectItem(si) => write!(f, "{si:?}"),
61            AstNode::JoinClause(j) => write!(f, "{j:?}"),
62            AstNode::OrderByItem(o) => write!(f, "{o:?}"),
63            AstNode::Cte(c) => write!(f, "{c:?}"),
64            AstNode::ColumnDef(cd) => write!(f, "{cd:?}"),
65            AstNode::TableConstraint(tc) => write!(f, "{tc:?}"),
66        }
67    }
68}
69
70/// Compute the semantic diff between two SQL statements.
71///
72/// Returns a list of [`ChangeAction`]s describing the minimal set of changes
73/// needed to transform `source` into `target`.
74#[must_use]
75pub fn diff(source: &Statement, target: &Statement) -> Vec<ChangeAction> {
76    let mut differ = AstDiffer::new();
77    differ.diff_statements(source, target);
78    differ.changes
79}
80
81/// Internal differ state that accumulates change actions.
82struct AstDiffer {
83    changes: Vec<ChangeAction>,
84}
85
86impl AstDiffer {
87    fn new() -> Self {
88        Self {
89            changes: Vec::new(),
90        }
91    }
92
93    fn diff_statements(&mut self, source: &Statement, target: &Statement) {
94        use Statement::*;
95
96        match (source, target) {
97            (Select(s), Select(t)) => self.diff_select(s, t),
98            (Insert(s), Insert(t)) => self.diff_insert(s, t),
99            (Update(s), Update(t)) => self.diff_update(s, t),
100            (Delete(s), Delete(t)) => self.diff_delete(s, t),
101            (CreateTable(s), CreateTable(t)) => self.diff_create_table(s, t),
102            (DropTable(s), DropTable(t)) => self.diff_drop_table(s, t),
103            (SetOperation(s), SetOperation(t)) => self.diff_set_operation(s, t),
104            (AlterTable(s), AlterTable(t)) => self.diff_alter_table(s, t),
105            (CreateView(s), CreateView(t)) => self.diff_create_view(s, t),
106            (Expression(s), Expression(t)) => self.diff_exprs(s, t),
107            _ => {
108                // Different statement types → remove old, insert new
109                self.changes
110                    .push(ChangeAction::Remove(AstNode::Statement(Box::new(
111                        source.clone(),
112                    ))));
113                self.changes
114                    .push(ChangeAction::Insert(AstNode::Statement(Box::new(
115                        target.clone(),
116                    ))));
117            }
118        }
119    }
120
121    // ── SELECT ─────────────────────────────────────────────────────────
122
123    fn diff_select(&mut self, source: &SelectStatement, target: &SelectStatement) {
124        // CTEs
125        self.diff_ctes(&source.ctes, &target.ctes);
126
127        // DISTINCT
128        if source.distinct != target.distinct {
129            if target.distinct {
130                self.changes
131                    .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
132                        table: None,
133                        name: "DISTINCT".to_string(),
134                        quote_style: QuoteStyle::None,
135                        table_quote_style: QuoteStyle::None,
136                    })));
137            } else {
138                self.changes
139                    .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
140                        table: None,
141                        name: "DISTINCT".to_string(),
142                        quote_style: QuoteStyle::None,
143                        table_quote_style: QuoteStyle::None,
144                    })));
145            }
146        }
147
148        // SELECT columns (ordered)
149        self.diff_select_items(&source.columns, &target.columns);
150
151        // FROM
152        match (&source.from, &target.from) {
153            (Some(sf), Some(tf)) => self.diff_table_sources(&sf.source, &tf.source),
154            (None, Some(tf)) => self.insert_table_source(&tf.source),
155            (Some(sf), None) => self.remove_table_source(&sf.source),
156            (None, None) => {}
157        }
158
159        // JOINs
160        self.diff_joins(&source.joins, &target.joins);
161
162        // WHERE
163        self.diff_optional_exprs(&source.where_clause, &target.where_clause);
164
165        // GROUP BY
166        self.diff_expr_lists(&source.group_by, &target.group_by);
167
168        // HAVING
169        self.diff_optional_exprs(&source.having, &target.having);
170
171        // ORDER BY
172        self.diff_order_by(&source.order_by, &target.order_by);
173
174        // LIMIT
175        self.diff_optional_exprs(&source.limit, &target.limit);
176
177        // OFFSET
178        self.diff_optional_exprs(&source.offset, &target.offset);
179
180        // QUALIFY
181        self.diff_optional_exprs(&source.qualify, &target.qualify);
182    }
183
184    // ── INSERT ─────────────────────────────────────────────────────────
185
186    fn diff_insert(&mut self, source: &InsertStatement, target: &InsertStatement) {
187        if source.table != target.table {
188            self.changes.push(ChangeAction::Update(
189                AstNode::Expr(table_ref_to_expr(&source.table)),
190                AstNode::Expr(table_ref_to_expr(&target.table)),
191            ));
192        }
193
194        // Column list
195        self.diff_string_lists(&source.columns, &target.columns);
196
197        // Source
198        match (&source.source, &target.source) {
199            (InsertSource::Values(sv), InsertSource::Values(tv)) => {
200                for (i, (sr, tr)) in sv.iter().zip(tv.iter()).enumerate() {
201                    self.diff_expr_lists(sr, tr);
202                    let _ = i;
203                }
204                for extra in sv.iter().skip(tv.len()) {
205                    for e in extra {
206                        self.changes
207                            .push(ChangeAction::Remove(AstNode::Expr(e.clone())));
208                    }
209                }
210                for extra in tv.iter().skip(sv.len()) {
211                    for e in extra {
212                        self.changes
213                            .push(ChangeAction::Insert(AstNode::Expr(e.clone())));
214                    }
215                }
216            }
217            (InsertSource::Query(sq), InsertSource::Query(tq)) => {
218                self.diff_statements(sq, tq);
219            }
220            _ => {
221                self.changes
222                    .push(ChangeAction::Remove(AstNode::Statement(Box::new(
223                        Statement::Insert(source.clone()),
224                    ))));
225                self.changes
226                    .push(ChangeAction::Insert(AstNode::Statement(Box::new(
227                        Statement::Insert(target.clone()),
228                    ))));
229            }
230        }
231    }
232
233    // ── UPDATE ─────────────────────────────────────────────────────────
234
235    fn diff_update(&mut self, source: &UpdateStatement, target: &UpdateStatement) {
236        if source.table != target.table {
237            self.changes.push(ChangeAction::Update(
238                AstNode::Expr(table_ref_to_expr(&source.table)),
239                AstNode::Expr(table_ref_to_expr(&target.table)),
240            ));
241        }
242
243        // Assignments (ordered by column name matching)
244        let source_map: HashMap<&str, &Expr> = source
245            .assignments
246            .iter()
247            .map(|(k, v)| (k.as_str(), v))
248            .collect();
249        let target_map: HashMap<&str, &Expr> = target
250            .assignments
251            .iter()
252            .map(|(k, v)| (k.as_str(), v))
253            .collect();
254
255        for (col, src_val) in &source_map {
256            if let Some(tgt_val) = target_map.get(col) {
257                self.diff_exprs(src_val, tgt_val);
258            } else {
259                self.changes
260                    .push(ChangeAction::Remove(AstNode::Expr((*src_val).clone())));
261            }
262        }
263        for (col, tgt_val) in &target_map {
264            if !source_map.contains_key(col) {
265                self.changes
266                    .push(ChangeAction::Insert(AstNode::Expr((*tgt_val).clone())));
267            }
268        }
269
270        self.diff_optional_exprs(&source.where_clause, &target.where_clause);
271    }
272
273    // ── DELETE ─────────────────────────────────────────────────────────
274
275    fn diff_delete(&mut self, source: &DeleteStatement, target: &DeleteStatement) {
276        if source.table != target.table {
277            self.changes.push(ChangeAction::Update(
278                AstNode::Expr(table_ref_to_expr(&source.table)),
279                AstNode::Expr(table_ref_to_expr(&target.table)),
280            ));
281        }
282        self.diff_optional_exprs(&source.where_clause, &target.where_clause);
283    }
284
285    // ── CREATE TABLE ───────────────────────────────────────────────────
286
287    fn diff_create_table(&mut self, source: &CreateTableStatement, target: &CreateTableStatement) {
288        if source.table != target.table {
289            self.changes.push(ChangeAction::Update(
290                AstNode::Expr(table_ref_to_expr(&source.table)),
291                AstNode::Expr(table_ref_to_expr(&target.table)),
292            ));
293        }
294
295        // Column definitions (match by name)
296        let source_cols: HashMap<&str, &ColumnDef> = source
297            .columns
298            .iter()
299            .map(|c| (c.name.as_str(), c))
300            .collect();
301        let target_cols: HashMap<&str, &ColumnDef> = target
302            .columns
303            .iter()
304            .map(|c| (c.name.as_str(), c))
305            .collect();
306
307        for (name, src_col) in &source_cols {
308            if let Some(tgt_col) = target_cols.get(name) {
309                if src_col != tgt_col {
310                    self.changes.push(ChangeAction::Update(
311                        AstNode::ColumnDef((*src_col).clone()),
312                        AstNode::ColumnDef((*tgt_col).clone()),
313                    ));
314                } else {
315                    self.changes.push(ChangeAction::Keep(
316                        AstNode::ColumnDef((*src_col).clone()),
317                        AstNode::ColumnDef((*tgt_col).clone()),
318                    ));
319                }
320            } else {
321                self.changes
322                    .push(ChangeAction::Remove(AstNode::ColumnDef((*src_col).clone())));
323            }
324        }
325        for (name, tgt_col) in &target_cols {
326            if !source_cols.contains_key(name) {
327                self.changes
328                    .push(ChangeAction::Insert(AstNode::ColumnDef((*tgt_col).clone())));
329            }
330        }
331
332        // Constraints
333        self.diff_constraints(&source.constraints, &target.constraints);
334    }
335
336    // ── DROP TABLE ─────────────────────────────────────────────────────
337
338    fn diff_drop_table(&mut self, source: &DropTableStatement, target: &DropTableStatement) {
339        if source != target {
340            self.changes.push(ChangeAction::Update(
341                AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
342                AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
343            ));
344        } else {
345            self.changes.push(ChangeAction::Keep(
346                AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
347                AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
348            ));
349        }
350    }
351
352    // ── SET OPERATION ──────────────────────────────────────────────────
353
354    fn diff_set_operation(
355        &mut self,
356        source: &SetOperationStatement,
357        target: &SetOperationStatement,
358    ) {
359        if source.op != target.op || source.all != target.all {
360            self.changes.push(ChangeAction::Update(
361                AstNode::Statement(Box::new(Statement::SetOperation(source.clone()))),
362                AstNode::Statement(Box::new(Statement::SetOperation(target.clone()))),
363            ));
364            return;
365        }
366        self.diff_statements(&source.left, &target.left);
367        self.diff_statements(&source.right, &target.right);
368        self.diff_order_by(&source.order_by, &target.order_by);
369        self.diff_optional_exprs(&source.limit, &target.limit);
370        self.diff_optional_exprs(&source.offset, &target.offset);
371    }
372
373    // ── ALTER TABLE ────────────────────────────────────────────────────
374
375    fn diff_alter_table(&mut self, source: &AlterTableStatement, target: &AlterTableStatement) {
376        if source.table != target.table {
377            self.changes.push(ChangeAction::Update(
378                AstNode::Expr(table_ref_to_expr(&source.table)),
379                AstNode::Expr(table_ref_to_expr(&target.table)),
380            ));
381        }
382        // Actions compared for equality
383        if source.actions != target.actions {
384            self.changes.push(ChangeAction::Update(
385                AstNode::Statement(Box::new(Statement::AlterTable(source.clone()))),
386                AstNode::Statement(Box::new(Statement::AlterTable(target.clone()))),
387            ));
388        }
389    }
390
391    // ── CREATE VIEW ────────────────────────────────────────────────────
392
393    fn diff_create_view(&mut self, source: &CreateViewStatement, target: &CreateViewStatement) {
394        if source.name != target.name {
395            self.changes.push(ChangeAction::Update(
396                AstNode::Expr(table_ref_to_expr(&source.name)),
397                AstNode::Expr(table_ref_to_expr(&target.name)),
398            ));
399        }
400        self.diff_statements(&source.query, &target.query);
401    }
402
403    // ── Shared helpers ─────────────────────────────────────────────────
404
405    fn diff_exprs(&mut self, source: &Expr, target: &Expr) {
406        if source == target {
407            self.changes.push(ChangeAction::Keep(
408                AstNode::Expr(source.clone()),
409                AstNode::Expr(target.clone()),
410            ));
411            return;
412        }
413
414        // Same top-level variant → recurse into children
415        match (source, target) {
416            (
417                Expr::BinaryOp {
418                    left: sl,
419                    op: sop,
420                    right: sr,
421                },
422                Expr::BinaryOp {
423                    left: tl,
424                    op: top,
425                    right: tr,
426                },
427            ) => {
428                if sop != top {
429                    self.changes.push(ChangeAction::Update(
430                        AstNode::Expr(source.clone()),
431                        AstNode::Expr(target.clone()),
432                    ));
433                } else {
434                    self.diff_exprs(sl, tl);
435                    self.diff_exprs(sr, tr);
436                }
437            }
438            (Expr::UnaryOp { op: sop, expr: se }, Expr::UnaryOp { op: top, expr: te }) => {
439                if sop != top {
440                    self.changes.push(ChangeAction::Update(
441                        AstNode::Expr(source.clone()),
442                        AstNode::Expr(target.clone()),
443                    ));
444                } else {
445                    self.diff_exprs(se, te);
446                }
447            }
448            (
449                Expr::Function {
450                    name: sn,
451                    args: sa,
452                    distinct: sd,
453                    ..
454                },
455                Expr::Function {
456                    name: tn,
457                    args: ta,
458                    distinct: td,
459                    ..
460                },
461            ) => {
462                if sn != tn || sd != td {
463                    self.changes.push(ChangeAction::Update(
464                        AstNode::Expr(source.clone()),
465                        AstNode::Expr(target.clone()),
466                    ));
467                } else {
468                    self.diff_expr_lists(sa, ta);
469                }
470            }
471            (
472                Expr::Cast {
473                    expr: se,
474                    data_type: sd,
475                },
476                Expr::Cast {
477                    expr: te,
478                    data_type: td,
479                },
480            ) => {
481                if sd != td {
482                    self.changes.push(ChangeAction::Update(
483                        AstNode::Expr(source.clone()),
484                        AstNode::Expr(target.clone()),
485                    ));
486                } else {
487                    self.diff_exprs(se, te);
488                }
489            }
490            (
491                Expr::Case {
492                    operand: so,
493                    when_clauses: sw,
494                    else_clause: se,
495                },
496                Expr::Case {
497                    operand: to,
498                    when_clauses: tw,
499                    else_clause: te,
500                },
501            ) => {
502                self.diff_optional_boxed_exprs(so, to);
503                // when clauses — ordered
504                for (i, ((sc, sr), (tc, tr))) in sw.iter().zip(tw.iter()).enumerate() {
505                    self.diff_exprs(sc, tc);
506                    self.diff_exprs(sr, tr);
507                    let _ = i;
508                }
509                for (sc, sr) in sw.iter().skip(tw.len()) {
510                    self.changes
511                        .push(ChangeAction::Remove(AstNode::Expr(sc.clone())));
512                    self.changes
513                        .push(ChangeAction::Remove(AstNode::Expr(sr.clone())));
514                }
515                for (tc, tr) in tw.iter().skip(sw.len()) {
516                    self.changes
517                        .push(ChangeAction::Insert(AstNode::Expr(tc.clone())));
518                    self.changes
519                        .push(ChangeAction::Insert(AstNode::Expr(tr.clone())));
520                }
521                self.diff_optional_boxed_exprs(se, te);
522            }
523            (Expr::Nested(se), Expr::Nested(te)) => self.diff_exprs(se, te),
524            (
525                Expr::Between {
526                    expr: se,
527                    low: sl,
528                    high: sh,
529                    negated: sn,
530                },
531                Expr::Between {
532                    expr: te,
533                    low: tl,
534                    high: th,
535                    negated: tn,
536                },
537            ) => {
538                if sn != tn {
539                    self.changes.push(ChangeAction::Update(
540                        AstNode::Expr(source.clone()),
541                        AstNode::Expr(target.clone()),
542                    ));
543                } else {
544                    self.diff_exprs(se, te);
545                    self.diff_exprs(sl, tl);
546                    self.diff_exprs(sh, th);
547                }
548            }
549            (
550                Expr::InList {
551                    expr: se,
552                    list: sl,
553                    negated: sn,
554                },
555                Expr::InList {
556                    expr: te,
557                    list: tl,
558                    negated: tn,
559                },
560            ) => {
561                if sn != tn {
562                    self.changes.push(ChangeAction::Update(
563                        AstNode::Expr(source.clone()),
564                        AstNode::Expr(target.clone()),
565                    ));
566                } else {
567                    self.diff_exprs(se, te);
568                    self.diff_expr_lists(sl, tl);
569                }
570            }
571            (
572                Expr::InSubquery {
573                    expr: se,
574                    subquery: ss,
575                    negated: sn,
576                },
577                Expr::InSubquery {
578                    expr: te,
579                    subquery: ts,
580                    negated: tn,
581                },
582            ) => {
583                if sn != tn {
584                    self.changes.push(ChangeAction::Update(
585                        AstNode::Expr(source.clone()),
586                        AstNode::Expr(target.clone()),
587                    ));
588                } else {
589                    self.diff_exprs(se, te);
590                    self.diff_statements(ss, ts);
591                }
592            }
593            (
594                Expr::IsNull {
595                    expr: se,
596                    negated: sn,
597                },
598                Expr::IsNull {
599                    expr: te,
600                    negated: tn,
601                },
602            ) => {
603                if sn != tn {
604                    self.changes.push(ChangeAction::Update(
605                        AstNode::Expr(source.clone()),
606                        AstNode::Expr(target.clone()),
607                    ));
608                } else {
609                    self.diff_exprs(se, te);
610                }
611            }
612            (
613                Expr::Like {
614                    expr: se,
615                    pattern: sp,
616                    negated: sn,
617                    ..
618                },
619                Expr::Like {
620                    expr: te,
621                    pattern: tp,
622                    negated: tn,
623                    ..
624                },
625            )
626            | (
627                Expr::ILike {
628                    expr: se,
629                    pattern: sp,
630                    negated: sn,
631                    ..
632                },
633                Expr::ILike {
634                    expr: te,
635                    pattern: tp,
636                    negated: tn,
637                    ..
638                },
639            ) => {
640                if sn != tn {
641                    self.changes.push(ChangeAction::Update(
642                        AstNode::Expr(source.clone()),
643                        AstNode::Expr(target.clone()),
644                    ));
645                } else {
646                    self.diff_exprs(se, te);
647                    self.diff_exprs(sp, tp);
648                }
649            }
650            (Expr::Subquery(ss), Expr::Subquery(ts)) => self.diff_statements(ss, ts),
651            (
652                Expr::Exists {
653                    subquery: ss,
654                    negated: sn,
655                },
656                Expr::Exists {
657                    subquery: ts,
658                    negated: tn,
659                },
660            ) => {
661                if sn != tn {
662                    self.changes.push(ChangeAction::Update(
663                        AstNode::Expr(source.clone()),
664                        AstNode::Expr(target.clone()),
665                    ));
666                } else {
667                    self.diff_statements(ss, ts);
668                }
669            }
670            (Expr::Alias { expr: se, name: sn }, Expr::Alias { expr: te, name: tn }) => {
671                if sn != tn {
672                    self.changes.push(ChangeAction::Update(
673                        AstNode::Expr(source.clone()),
674                        AstNode::Expr(target.clone()),
675                    ));
676                } else {
677                    self.diff_exprs(se, te);
678                }
679            }
680            (Expr::Coalesce(sa), Expr::Coalesce(ta)) => self.diff_expr_lists(sa, ta),
681            (Expr::ArrayLiteral(sa), Expr::ArrayLiteral(ta)) => self.diff_expr_lists(sa, ta),
682            (Expr::Tuple(sa), Expr::Tuple(ta)) => self.diff_expr_lists(sa, ta),
683            (Expr::TypedFunction { func: sf, .. }, Expr::TypedFunction { func: tf, .. }) => {
684                if std::mem::discriminant(sf) == std::mem::discriminant(tf) && source == target {
685                    self.changes.push(ChangeAction::Keep(
686                        AstNode::Expr(source.clone()),
687                        AstNode::Expr(target.clone()),
688                    ));
689                } else {
690                    self.changes.push(ChangeAction::Update(
691                        AstNode::Expr(source.clone()),
692                        AstNode::Expr(target.clone()),
693                    ));
694                }
695            }
696            // Different variant types → leaf-level update
697            _ => {
698                self.changes.push(ChangeAction::Update(
699                    AstNode::Expr(source.clone()),
700                    AstNode::Expr(target.clone()),
701                ));
702            }
703        }
704    }
705
706    /// Diff two ordered expression lists (e.g., SELECT columns, function args).
707    fn diff_expr_lists(&mut self, source: &[Expr], target: &[Expr]) {
708        // Use longest common subsequence for ordered diff
709        let lcs = compute_lcs(source, target);
710        let mut si = 0;
711        let mut ti = 0;
712        let mut li = 0;
713
714        while si < source.len() || ti < target.len() {
715            if li < lcs.len() {
716                let (lcs_si, lcs_ti) = lcs[li];
717
718                // Remove items before the next LCS match in source
719                while si < lcs_si {
720                    self.changes
721                        .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
722                    si += 1;
723                }
724                // Insert items before the next LCS match in target
725                while ti < lcs_ti {
726                    self.changes
727                        .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
728                    ti += 1;
729                }
730                // Matched pair — recurse to find deeper changes
731                self.diff_exprs(&source[si], &target[ti]);
732                si += 1;
733                ti += 1;
734                li += 1;
735            } else {
736                // Remaining source items are removed
737                while si < source.len() {
738                    self.changes
739                        .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
740                    si += 1;
741                }
742                // Remaining target items are inserted
743                while ti < target.len() {
744                    self.changes
745                        .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
746                    ti += 1;
747                }
748            }
749        }
750    }
751
752    fn diff_select_items(&mut self, source: &[SelectItem], target: &[SelectItem]) {
753        let min_len = source.len().min(target.len());
754        for i in 0..min_len {
755            if source[i] == target[i] {
756                self.changes.push(ChangeAction::Keep(
757                    AstNode::SelectItem(source[i].clone()),
758                    AstNode::SelectItem(target[i].clone()),
759                ));
760            } else {
761                match (&source[i], &target[i]) {
762                    (
763                        SelectItem::Expr {
764                            expr: se,
765                            alias: sa,
766                            ..
767                        },
768                        SelectItem::Expr {
769                            expr: te,
770                            alias: ta,
771                            ..
772                        },
773                    ) => {
774                        if sa != ta {
775                            self.changes.push(ChangeAction::Update(
776                                AstNode::SelectItem(source[i].clone()),
777                                AstNode::SelectItem(target[i].clone()),
778                            ));
779                        } else {
780                            self.diff_exprs(se, te);
781                        }
782                    }
783                    _ => {
784                        self.changes.push(ChangeAction::Update(
785                            AstNode::SelectItem(source[i].clone()),
786                            AstNode::SelectItem(target[i].clone()),
787                        ));
788                    }
789                }
790            }
791        }
792        for item in source.iter().skip(min_len) {
793            self.changes
794                .push(ChangeAction::Remove(AstNode::SelectItem(item.clone())));
795        }
796        for item in target.iter().skip(min_len) {
797            self.changes
798                .push(ChangeAction::Insert(AstNode::SelectItem(item.clone())));
799        }
800    }
801
802    fn diff_optional_exprs(&mut self, source: &Option<Expr>, target: &Option<Expr>) {
803        match (source, target) {
804            (Some(s), Some(t)) => self.diff_exprs(s, t),
805            (None, Some(t)) => self
806                .changes
807                .push(ChangeAction::Insert(AstNode::Expr(t.clone()))),
808            (Some(s), None) => self
809                .changes
810                .push(ChangeAction::Remove(AstNode::Expr(s.clone()))),
811            (None, None) => {}
812        }
813    }
814
815    fn diff_optional_boxed_exprs(
816        &mut self,
817        source: &Option<Box<Expr>>,
818        target: &Option<Box<Expr>>,
819    ) {
820        match (source, target) {
821            (Some(s), Some(t)) => self.diff_exprs(s, t),
822            (None, Some(t)) => self
823                .changes
824                .push(ChangeAction::Insert(AstNode::Expr((**t).clone()))),
825            (Some(s), None) => self
826                .changes
827                .push(ChangeAction::Remove(AstNode::Expr((**s).clone()))),
828            (None, None) => {}
829        }
830    }
831
832    fn diff_ctes(&mut self, source: &[Cte], target: &[Cte]) {
833        // Match CTEs by name
834        let source_map: HashMap<&str, &Cte> = source.iter().map(|c| (c.name.as_str(), c)).collect();
835        let target_map: HashMap<&str, &Cte> = target.iter().map(|c| (c.name.as_str(), c)).collect();
836
837        for (name, sc) in &source_map {
838            if let Some(tc) = target_map.get(name) {
839                if sc == tc {
840                    self.changes.push(ChangeAction::Keep(
841                        AstNode::Cte(Box::new((*sc).clone())),
842                        AstNode::Cte(Box::new((*tc).clone())),
843                    ));
844                } else {
845                    self.diff_statements(&sc.query, &tc.query);
846                }
847            } else {
848                self.changes
849                    .push(ChangeAction::Remove(AstNode::Cte(Box::new((*sc).clone()))));
850            }
851        }
852        for (name, tc) in &target_map {
853            if !source_map.contains_key(name) {
854                self.changes
855                    .push(ChangeAction::Insert(AstNode::Cte(Box::new((*tc).clone()))));
856            }
857        }
858    }
859
860    fn diff_joins(&mut self, source: &[JoinClause], target: &[JoinClause]) {
861        let min_len = source.len().min(target.len());
862        for i in 0..min_len {
863            if source[i] == target[i] {
864                self.changes.push(ChangeAction::Keep(
865                    AstNode::JoinClause(source[i].clone()),
866                    AstNode::JoinClause(target[i].clone()),
867                ));
868            } else if source[i].join_type == target[i].join_type {
869                // Same join type, diff the contents
870                self.diff_table_sources(&source[i].table, &target[i].table);
871                self.diff_optional_exprs(&source[i].on, &target[i].on);
872            } else {
873                self.changes.push(ChangeAction::Update(
874                    AstNode::JoinClause(source[i].clone()),
875                    AstNode::JoinClause(target[i].clone()),
876                ));
877            }
878        }
879        for item in source.iter().skip(min_len) {
880            self.changes
881                .push(ChangeAction::Remove(AstNode::JoinClause(item.clone())));
882        }
883        for item in target.iter().skip(min_len) {
884            self.changes
885                .push(ChangeAction::Insert(AstNode::JoinClause(item.clone())));
886        }
887    }
888
889    fn diff_order_by(&mut self, source: &[OrderByItem], target: &[OrderByItem]) {
890        let min_len = source.len().min(target.len());
891        for i in 0..min_len {
892            if source[i] == target[i] {
893                self.changes.push(ChangeAction::Keep(
894                    AstNode::OrderByItem(source[i].clone()),
895                    AstNode::OrderByItem(target[i].clone()),
896                ));
897            } else if source[i].ascending == target[i].ascending
898                && source[i].nulls_first == target[i].nulls_first
899            {
900                self.diff_exprs(&source[i].expr, &target[i].expr);
901            } else {
902                self.changes.push(ChangeAction::Update(
903                    AstNode::OrderByItem(source[i].clone()),
904                    AstNode::OrderByItem(target[i].clone()),
905                ));
906            }
907        }
908        for item in source.iter().skip(min_len) {
909            self.changes
910                .push(ChangeAction::Remove(AstNode::OrderByItem(item.clone())));
911        }
912        for item in target.iter().skip(min_len) {
913            self.changes
914                .push(ChangeAction::Insert(AstNode::OrderByItem(item.clone())));
915        }
916    }
917
918    fn diff_table_sources(&mut self, source: &TableSource, target: &TableSource) {
919        if source == target {
920            return;
921        }
922        match (source, target) {
923            (TableSource::Table(st), TableSource::Table(tt)) => {
924                if st != tt {
925                    self.changes.push(ChangeAction::Update(
926                        AstNode::Expr(table_ref_to_expr(st)),
927                        AstNode::Expr(table_ref_to_expr(tt)),
928                    ));
929                }
930            }
931            (TableSource::Subquery { query: sq, .. }, TableSource::Subquery { query: tq, .. }) => {
932                self.diff_statements(sq, tq);
933            }
934            _ => {
935                // Different source types
936                self.remove_table_source(source);
937                self.insert_table_source(target);
938            }
939        }
940    }
941
942    fn insert_table_source(&mut self, source: &TableSource) {
943        match source {
944            TableSource::Table(t) => {
945                self.changes
946                    .push(ChangeAction::Insert(AstNode::Expr(table_ref_to_expr(t))));
947            }
948            TableSource::Subquery { query, .. } => {
949                self.changes
950                    .push(ChangeAction::Insert(AstNode::Statement(Box::new(
951                        (**query).clone(),
952                    ))));
953            }
954            other => {
955                self.changes
956                    .push(ChangeAction::Insert(AstNode::Expr(Expr::StringLiteral(
957                        format!("{other:?}"),
958                    ))));
959            }
960        }
961    }
962
963    fn remove_table_source(&mut self, source: &TableSource) {
964        match source {
965            TableSource::Table(t) => {
966                self.changes
967                    .push(ChangeAction::Remove(AstNode::Expr(table_ref_to_expr(t))));
968            }
969            TableSource::Subquery { query, .. } => {
970                self.changes
971                    .push(ChangeAction::Remove(AstNode::Statement(Box::new(
972                        (**query).clone(),
973                    ))));
974            }
975            other => {
976                self.changes
977                    .push(ChangeAction::Remove(AstNode::Expr(Expr::StringLiteral(
978                        format!("{other:?}"),
979                    ))));
980            }
981        }
982    }
983
984    fn diff_constraints(&mut self, source: &[TableConstraint], target: &[TableConstraint]) {
985        // Simple positional diff for constraints
986        let min_len = source.len().min(target.len());
987        for i in 0..min_len {
988            if source[i] == target[i] {
989                self.changes.push(ChangeAction::Keep(
990                    AstNode::TableConstraint(source[i].clone()),
991                    AstNode::TableConstraint(target[i].clone()),
992                ));
993            } else {
994                self.changes.push(ChangeAction::Update(
995                    AstNode::TableConstraint(source[i].clone()),
996                    AstNode::TableConstraint(target[i].clone()),
997                ));
998            }
999        }
1000        for item in source.iter().skip(min_len) {
1001            self.changes
1002                .push(ChangeAction::Remove(AstNode::TableConstraint(item.clone())));
1003        }
1004        for item in target.iter().skip(min_len) {
1005            self.changes
1006                .push(ChangeAction::Insert(AstNode::TableConstraint(item.clone())));
1007        }
1008    }
1009
1010    fn diff_string_lists(&mut self, source: &[String], target: &[String]) {
1011        for s in source {
1012            if !target.contains(s) {
1013                self.changes
1014                    .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
1015                        table: None,
1016                        name: s.clone(),
1017                        quote_style: QuoteStyle::None,
1018                        table_quote_style: QuoteStyle::None,
1019                    })));
1020            }
1021        }
1022        for t in target {
1023            if !source.contains(t) {
1024                self.changes
1025                    .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
1026                        table: None,
1027                        name: t.clone(),
1028                        quote_style: QuoteStyle::None,
1029                        table_quote_style: QuoteStyle::None,
1030                    })));
1031            }
1032        }
1033    }
1034}
1035
1036// ═══════════════════════════════════════════════════════════════════════
1037// LCS — Longest Common Subsequence for ordered diff
1038// ═══════════════════════════════════════════════════════════════════════
1039
/// Compute a longest common subsequence of two slices.
///
/// Returns `(source_index, target_index)` pairs of matching elements, with
/// both indices strictly increasing. Elements are compared with `==`, so
/// the function works for any `T: PartialEq`, including expression slices.
/// An empty input on either side yields an empty result.
///
/// When several longest subsequences exist, the backtracking step prefers
/// stepping back in `source` over stepping back in `target`, which fixes
/// which one is returned.
///
/// Runs in `O(m * n)` time and space for inputs of length `m` and `n`.
fn compute_lcs<T: PartialEq>(source: &[T], target: &[T]) -> Vec<(usize, usize)> {
    let m = source.len();
    let n = target.len();
    if m == 0 || n == 0 {
        return Vec::new();
    }

    // dp[i][j] = LCS length of source[..i] and target[..j].
    let mut dp = vec![vec![0usize; n + 1]; m + 1];
    for (i, s) in source.iter().enumerate() {
        for (j, t) in target.iter().enumerate() {
            dp[i + 1][j + 1] = if s == t {
                dp[i][j] + 1
            } else {
                dp[i][j + 1].max(dp[i + 1][j])
            };
        }
    }

    // Walk back from the bottom-right corner to recover the matched pairs.
    let mut result = Vec::with_capacity(dp[m][n]);
    let mut i = m;
    let mut j = n;
    while i > 0 && j > 0 {
        if source[i - 1] == target[j - 1] {
            result.push((i - 1, j - 1));
            i -= 1;
            j -= 1;
        } else if dp[i - 1][j] >= dp[i][j - 1] {
            i -= 1;
        } else {
            j -= 1;
        }
    }
    // Pairs were collected back to front.
    result.reverse();
    result
}
1079
1080/// Convert a `TableRef` to an `Expr::Column` for uniform representation.
1081fn table_ref_to_expr(table: &TableRef) -> Expr {
1082    let full_name = match (&table.catalog, &table.schema) {
1083        (Some(c), Some(s)) => format!("{c}.{s}.{}", table.name),
1084        (None, Some(s)) => format!("{s}.{}", table.name),
1085        _ => table.name.clone(),
1086    };
1087    Expr::Column {
1088        table: table.schema.clone(),
1089        name: full_name,
1090        quote_style: table.name_quote_style,
1091        table_quote_style: QuoteStyle::None,
1092    }
1093}
1094
1095// ═══════════════════════════════════════════════════════════════════════
1096// Convenience: diff from SQL strings
1097// ═══════════════════════════════════════════════════════════════════════
1098
1099/// Parse two SQL strings and compute their diff.
1100///
1101/// # Errors
1102///
1103/// Returns a [`SqlglotError`](crate::errors::SqlglotError) if either
1104/// string fails to parse.
1105pub fn diff_sql(
1106    source_sql: &str,
1107    target_sql: &str,
1108    dialect: crate::dialects::Dialect,
1109) -> crate::errors::Result<Vec<ChangeAction>> {
1110    let source = crate::parser::parse(source_sql, dialect)?;
1111    let target = crate::parser::parse(target_sql, dialect)?;
1112    Ok(diff(&source, &target))
1113}
1114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Number of actions of each asserted kind in a diff result.
    #[derive(Debug, Default)]
    struct Tally {
        keeps: usize,
        inserts: usize,
        removes: usize,
        updates: usize,
    }

    /// Count the actions in `changes` by variant.
    fn tally(changes: &[ChangeAction]) -> Tally {
        let mut totals = Tally::default();
        for change in changes {
            match change {
                ChangeAction::Keep(..) => totals.keeps += 1,
                ChangeAction::Insert(..) => totals.inserts += 1,
                ChangeAction::Remove(..) => totals.removes += 1,
                ChangeAction::Update(..) => totals.updates += 1,
                // No test asserts on moves, so they are not tallied.
                ChangeAction::Move(..) => {}
            }
        }
        totals
    }

    /// Parse both statements as ANSI SQL, diff them, and tally the result.
    fn ansi_diff(source_sql: &str, target_sql: &str) -> Tally {
        let source = parse(source_sql, Dialect::Ansi).unwrap();
        let target = parse(target_sql, Dialect::Ansi).unwrap();
        tally(&diff(&source, &target))
    }

    #[test]
    fn test_identical_queries_are_all_keep() {
        let sql = "SELECT a, b FROM t WHERE a > 1";
        let counts = ansi_diff(sql, sql);
        assert!(counts.keeps > 0, "should have keep actions");
        assert_eq!(counts.inserts, 0, "no inserts for identical queries");
        assert_eq!(counts.removes, 0, "no removes for identical queries");
        assert_eq!(counts.updates, 0, "no updates for identical queries");
    }

    #[test]
    fn test_column_added() {
        let counts = ansi_diff("SELECT a FROM t", "SELECT a, b FROM t");
        assert!(counts.keeps > 0);
        assert!(counts.inserts > 0, "should have insert for new column b");
        assert_eq!(counts.removes, 0);
    }

    #[test]
    fn test_column_removed() {
        let counts = ansi_diff("SELECT a, b FROM t", "SELECT a FROM t");
        assert!(counts.keeps > 0);
        assert!(counts.removes > 0, "should have remove for column b");
    }

    #[test]
    fn test_column_changed() {
        let counts = ansi_diff("SELECT a, b FROM t", "SELECT a, c FROM t");
        assert!(counts.updates > 0, "should have update for b -> c");
    }

    #[test]
    fn test_where_clause_added() {
        let counts = ansi_diff("SELECT a FROM t", "SELECT a FROM t WHERE a > 1");
        assert!(counts.inserts > 0, "should have insert for WHERE clause");
    }

    #[test]
    fn test_where_clause_removed() {
        let counts = ansi_diff("SELECT a FROM t WHERE a > 1", "SELECT a FROM t");
        assert!(counts.removes > 0, "should have remove for WHERE clause");
    }

    #[test]
    fn test_where_clause_updated() {
        let counts = ansi_diff("SELECT a FROM t WHERE a > 1", "SELECT a FROM t WHERE a > 2");
        assert!(
            counts.updates > 0,
            "should have update for WHERE value change"
        );
    }

    #[test]
    fn test_table_changed() {
        let counts = ansi_diff("SELECT a FROM t1", "SELECT a FROM t2");
        assert!(counts.updates > 0, "should have update for table change");
    }

    #[test]
    fn test_join_added() {
        let counts = ansi_diff(
            "SELECT a FROM t1",
            "SELECT a FROM t1 JOIN t2 ON t1.id = t2.id",
        );
        assert!(counts.inserts > 0, "should have insert for JOIN");
    }

    #[test]
    fn test_order_by_changed() {
        let counts = ansi_diff(
            "SELECT a FROM t ORDER BY a ASC",
            "SELECT a FROM t ORDER BY a DESC",
        );
        assert!(
            counts.updates > 0,
            "should have update for ORDER BY direction"
        );
    }

    #[test]
    fn test_complex_nested_query() {
        let counts = ansi_diff(
            "SELECT a, b FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 0)",
            "SELECT a, c FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 5)",
        );
        assert!(counts.keeps > 0, "unchanged parts should be kept");
        assert!(
            counts.updates > 0,
            "changed parts should be updated (b->c, 0->5)"
        );
    }

    #[test]
    fn test_different_statement_types() {
        let counts = ansi_diff("SELECT a FROM t", "CREATE TABLE t (a INT)");
        assert!(counts.removes > 0, "source should be removed");
        assert!(counts.inserts > 0, "target should be inserted");
    }

    #[test]
    fn test_cte_added() {
        let counts = ansi_diff(
            "SELECT a FROM t",
            "WITH cte AS (SELECT 1 AS x) SELECT a FROM t",
        );
        assert!(counts.inserts > 0, "should have insert for CTE");
    }

    #[test]
    fn test_limit_changed() {
        let counts = ansi_diff("SELECT a FROM t LIMIT 10", "SELECT a FROM t LIMIT 20");
        assert!(counts.updates > 0, "should have update for LIMIT change");
    }

    #[test]
    fn test_group_by_added() {
        let counts = ansi_diff(
            "SELECT a, COUNT(*) FROM t",
            "SELECT a, COUNT(*) FROM t GROUP BY a",
        );
        assert!(counts.inserts > 0, "should have insert for GROUP BY");
    }

    #[test]
    fn test_diff_sql_convenience() {
        let changes = diff_sql("SELECT a FROM t", "SELECT a, b FROM t", Dialect::Ansi).unwrap();
        let counts = tally(&changes);
        assert!(counts.inserts > 0);
    }

    #[test]
    fn test_having_added() {
        let counts = ansi_diff(
            "SELECT a, COUNT(*) FROM t GROUP BY a",
            "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
        );
        assert!(counts.inserts > 0, "should have insert for HAVING");
    }

    #[test]
    fn test_create_table_column_diff() {
        let counts = ansi_diff(
            "CREATE TABLE t (a INT, b TEXT)",
            "CREATE TABLE t (a INT, c TEXT)",
        );
        assert!(counts.removes > 0, "should remove column b");
        assert!(counts.inserts > 0, "should insert column c");
    }

    #[test]
    fn test_union_diff() {
        let counts = ansi_diff(
            "SELECT a FROM t1 UNION SELECT b FROM t2",
            "SELECT a FROM t1 UNION SELECT c FROM t2",
        );
        assert!(counts.keeps > 0);
        assert!(counts.updates > 0, "should have update for b -> c");
    }
}