Skip to main content

sqrust_rules/ambiguous/
self_comparison.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4pub struct SelfComparison;
5
6impl Rule for SelfComparison {
7    fn name(&self) -> &'static str {
8        "Ambiguous/SelfComparison"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            collect_from_statement(stmt, ctx, &mut diags);
19        }
20        diags
21    }
22}
23
24fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
25    if let Statement::Query(query) = stmt {
26        collect_from_query(query, ctx, diags);
27    }
28}
29
30fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
31    if let Some(with) = &query.with {
32        for cte in &with.cte_tables {
33            collect_from_query(&cte.query, ctx, diags);
34        }
35    }
36    collect_from_set_expr(&query.body, ctx, diags);
37}
38
39fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
40    match expr {
41        SetExpr::Select(select) => {
42            collect_from_select(select, ctx, diags);
43        }
44        SetExpr::Query(inner) => {
45            collect_from_query(inner, ctx, diags);
46        }
47        SetExpr::SetOperation { left, right, .. } => {
48            collect_from_set_expr(left, ctx, diags);
49            collect_from_set_expr(right, ctx, diags);
50        }
51        _ => {}
52    }
53}
54
55fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
56    // Check SELECT projection.
57    for item in &select.projection {
58        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
59            check_expr(e, ctx, diags);
60        }
61    }
62
63    // Check FROM subqueries.
64    for twj in &select.from {
65        collect_from_table_factor(&twj.relation, ctx, diags);
66        for join in &twj.joins {
67            collect_from_table_factor(&join.relation, ctx, diags);
68        }
69    }
70
71    // Check WHERE.
72    if let Some(selection) = &select.selection {
73        check_expr(selection, ctx, diags);
74    }
75
76    // Check HAVING.
77    if let Some(having) = &select.having {
78        check_expr(having, ctx, diags);
79    }
80}
81
82fn collect_from_table_factor(
83    factor: &TableFactor,
84    ctx: &FileContext,
85    diags: &mut Vec<Diagnostic>,
86) {
87    if let TableFactor::Derived { subquery, .. } = factor {
88        collect_from_query(subquery, ctx, diags);
89    }
90}
91
92/// Returns true when the operator is a comparison that would be trivially
93/// redundant when both operands are the same.
94fn is_comparison_op(op: &BinaryOperator) -> bool {
95    matches!(
96        op,
97        BinaryOperator::Eq
98            | BinaryOperator::NotEq
99            | BinaryOperator::Lt
100            | BinaryOperator::Gt
101            | BinaryOperator::LtEq
102            | BinaryOperator::GtEq
103    )
104}
105
106/// Strips `Nested` wrappers so that `(col)` resolves to the inner identifier.
107fn unwrap_nested(expr: &Expr) -> &Expr {
108    let mut current = expr;
109    while let Expr::Nested(inner) = current {
110        current = inner;
111    }
112    current
113}
114
115/// Case-insensitive equality check for identifier-like expressions.
116/// Both sides must be plain or compound identifiers. Returns the display name
117/// when they are equal, `None` otherwise.
118fn self_comparison_name<'a>(left: &'a Expr, right: &'a Expr) -> Option<String> {
119    let l = unwrap_nested(left);
120    let r = unwrap_nested(right);
121
122    match (l, r) {
123        (Expr::Identifier(li), Expr::Identifier(ri)) => {
124            if li.value.to_lowercase() == ri.value.to_lowercase() {
125                Some(li.value.clone())
126            } else {
127                None
128            }
129        }
130        (Expr::CompoundIdentifier(lparts), Expr::CompoundIdentifier(rparts)) => {
131            if lparts.len() == rparts.len()
132                && lparts
133                    .iter()
134                    .zip(rparts.iter())
135                    .all(|(a, b)| a.value.to_lowercase() == b.value.to_lowercase())
136            {
137                let name = lparts
138                    .iter()
139                    .map(|i| i.value.as_str())
140                    .collect::<Vec<_>>()
141                    .join(".");
142                Some(name)
143            } else {
144                None
145            }
146        }
147        _ => None,
148    }
149}
150
151fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
152    match expr {
153        Expr::BinaryOp { left, op, right } => {
154            // Recurse into children first.
155            check_expr(left, ctx, diags);
156            check_expr(right, ctx, diags);
157
158            // Then check whether this node is a self-comparison.
159            if is_comparison_op(op) {
160                if let Some(name) = self_comparison_name(left, right) {
161                    let (line, col) = find_identifier_position(&ctx.source, &name);
162                    diags.push(Diagnostic {
163                        rule: "Ambiguous/SelfComparison",
164                        message: format!(
165                            "Expression compares '{}' to itself; this is always TRUE or NULL",
166                            name
167                        ),
168                        line,
169                        col,
170                    });
171                }
172            }
173        }
174        Expr::UnaryOp { expr: inner, .. } => {
175            check_expr(inner, ctx, diags);
176        }
177        Expr::Nested(inner) => {
178            check_expr(inner, ctx, diags);
179        }
180        Expr::Case {
181            operand,
182            conditions,
183            results,
184            else_result,
185        } => {
186            if let Some(op) = operand {
187                check_expr(op, ctx, diags);
188            }
189            for cond in conditions {
190                check_expr(cond, ctx, diags);
191            }
192            for result in results {
193                check_expr(result, ctx, diags);
194            }
195            if let Some(else_e) = else_result {
196                check_expr(else_e, ctx, diags);
197            }
198        }
199        Expr::InList { expr: inner, list, .. } => {
200            check_expr(inner, ctx, diags);
201            for e in list {
202                check_expr(e, ctx, diags);
203            }
204        }
205        Expr::Between { expr: inner, low, high, .. } => {
206            check_expr(inner, ctx, diags);
207            check_expr(low, ctx, diags);
208            check_expr(high, ctx, diags);
209        }
210        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
211            check_expr(inner, ctx, diags);
212        }
213        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
214            collect_from_query(q, ctx, diags);
215        }
216        _ => {}
217    }
218}
219
/// Finds the first word-boundary occurrence of `name` in `source` and returns
/// its 1-indexed `(line, col)`. Falls back to `(1, 1)` when `name` is empty or
/// does not occur.
///
/// Matching ignores ASCII case only. ASCII case folding never changes byte
/// lengths, so every offset found here is a valid offset into `source` itself;
/// full Unicode upper-casing can lengthen or shorten text and would shift
/// offsets or land them inside a multi-byte character.
///
/// A word boundary means the neighbouring byte on each side is missing or is
/// neither an ASCII letter/digit nor `_`.
fn find_identifier_position(source: &str, name: &str) -> (usize, usize) {
    let haystack = source.as_bytes();
    let needle = name.as_bytes();
    // `windows(0)` would panic, and an empty name has no meaningful position.
    if needle.is_empty() {
        return (1, 1);
    }

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    haystack
        .windows(needle.len())
        .enumerate()
        .find(|&(start, window)| {
            window.eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && haystack
                    .get(start + needle.len())
                    .map_or(true, |&b| !is_word_byte(b))
        })
        .map_or((1, 1), |(start, _)| offset_to_line_col(source, start))
}

/// Converts a byte offset into `source` to a 1-indexed `(line, col)`.
///
/// `col` is measured in bytes from the start of the line. `offset` must lie on
/// a character boundary within `source`; slicing panics otherwise.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.chars().filter(|&c| c == '\n').count() + 1;
    let col = before.rfind('\n').map(|p| offset - p - 1).unwrap_or(offset) + 1;
    (line, col)
}