1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4pub struct SelfComparison;
5
6impl Rule for SelfComparison {
7 fn name(&self) -> &'static str {
8 "Ambiguous/SelfComparison"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
13 return Vec::new();
14 }
15
16 let mut diags = Vec::new();
17 for stmt in &ctx.statements {
18 collect_from_statement(stmt, ctx, &mut diags);
19 }
20 diags
21 }
22}
23
24fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
25 if let Statement::Query(query) = stmt {
26 collect_from_query(query, ctx, diags);
27 }
28}
29
30fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
31 if let Some(with) = &query.with {
32 for cte in &with.cte_tables {
33 collect_from_query(&cte.query, ctx, diags);
34 }
35 }
36 collect_from_set_expr(&query.body, ctx, diags);
37}
38
39fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
40 match expr {
41 SetExpr::Select(select) => {
42 collect_from_select(select, ctx, diags);
43 }
44 SetExpr::Query(inner) => {
45 collect_from_query(inner, ctx, diags);
46 }
47 SetExpr::SetOperation { left, right, .. } => {
48 collect_from_set_expr(left, ctx, diags);
49 collect_from_set_expr(right, ctx, diags);
50 }
51 _ => {}
52 }
53}
54
55fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
56 for item in &select.projection {
58 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
59 check_expr(e, ctx, diags);
60 }
61 }
62
63 for twj in &select.from {
65 collect_from_table_factor(&twj.relation, ctx, diags);
66 for join in &twj.joins {
67 collect_from_table_factor(&join.relation, ctx, diags);
68 }
69 }
70
71 if let Some(selection) = &select.selection {
73 check_expr(selection, ctx, diags);
74 }
75
76 if let Some(having) = &select.having {
78 check_expr(having, ctx, diags);
79 }
80}
81
82fn collect_from_table_factor(
83 factor: &TableFactor,
84 ctx: &FileContext,
85 diags: &mut Vec<Diagnostic>,
86) {
87 if let TableFactor::Derived { subquery, .. } = factor {
88 collect_from_query(subquery, ctx, diags);
89 }
90}
91
92fn is_comparison_op(op: &BinaryOperator) -> bool {
95 matches!(
96 op,
97 BinaryOperator::Eq
98 | BinaryOperator::NotEq
99 | BinaryOperator::Lt
100 | BinaryOperator::Gt
101 | BinaryOperator::LtEq
102 | BinaryOperator::GtEq
103 )
104}
105
106fn unwrap_nested(expr: &Expr) -> &Expr {
108 let mut current = expr;
109 while let Expr::Nested(inner) = current {
110 current = inner;
111 }
112 current
113}
114
115fn self_comparison_name<'a>(left: &'a Expr, right: &'a Expr) -> Option<String> {
119 let l = unwrap_nested(left);
120 let r = unwrap_nested(right);
121
122 match (l, r) {
123 (Expr::Identifier(li), Expr::Identifier(ri)) => {
124 if li.value.to_lowercase() == ri.value.to_lowercase() {
125 Some(li.value.clone())
126 } else {
127 None
128 }
129 }
130 (Expr::CompoundIdentifier(lparts), Expr::CompoundIdentifier(rparts)) => {
131 if lparts.len() == rparts.len()
132 && lparts
133 .iter()
134 .zip(rparts.iter())
135 .all(|(a, b)| a.value.to_lowercase() == b.value.to_lowercase())
136 {
137 let name = lparts
138 .iter()
139 .map(|i| i.value.as_str())
140 .collect::<Vec<_>>()
141 .join(".");
142 Some(name)
143 } else {
144 None
145 }
146 }
147 _ => None,
148 }
149}
150
151fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
152 match expr {
153 Expr::BinaryOp { left, op, right } => {
154 check_expr(left, ctx, diags);
156 check_expr(right, ctx, diags);
157
158 if is_comparison_op(op) {
160 if let Some(name) = self_comparison_name(left, right) {
161 let (line, col) = find_identifier_position(&ctx.source, &name);
162 diags.push(Diagnostic {
163 rule: "Ambiguous/SelfComparison",
164 message: format!(
165 "Expression compares '{}' to itself; this is always TRUE or NULL",
166 name
167 ),
168 line,
169 col,
170 });
171 }
172 }
173 }
174 Expr::UnaryOp { expr: inner, .. } => {
175 check_expr(inner, ctx, diags);
176 }
177 Expr::Nested(inner) => {
178 check_expr(inner, ctx, diags);
179 }
180 Expr::Case {
181 operand,
182 conditions,
183 results,
184 else_result,
185 } => {
186 if let Some(op) = operand {
187 check_expr(op, ctx, diags);
188 }
189 for cond in conditions {
190 check_expr(cond, ctx, diags);
191 }
192 for result in results {
193 check_expr(result, ctx, diags);
194 }
195 if let Some(else_e) = else_result {
196 check_expr(else_e, ctx, diags);
197 }
198 }
199 Expr::InList { expr: inner, list, .. } => {
200 check_expr(inner, ctx, diags);
201 for e in list {
202 check_expr(e, ctx, diags);
203 }
204 }
205 Expr::Between { expr: inner, low, high, .. } => {
206 check_expr(inner, ctx, diags);
207 check_expr(low, ctx, diags);
208 check_expr(high, ctx, diags);
209 }
210 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
211 check_expr(inner, ctx, diags);
212 }
213 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
214 collect_from_query(q, ctx, diags);
215 }
216 _ => {}
217 }
218}
219
/// Returns the 1-based `(line, column)` of the first whole-word,
/// ASCII-case-insensitive occurrence of `name` in `source`, or `(1, 1)` if
/// there is none.
///
/// A match must not be preceded or followed by an ASCII alphanumeric byte or `_`.
/// The search covers the whole text (comments, string literals, other statements
/// included), so the first occurrence is not necessarily the comparison being reported.
///
/// Only ASCII letters are case-folded. Full Unicode uppercasing can change a
/// string's byte length (e.g. 'ı' -> 'I', 'ŉ' -> "ʼN"). That would make offsets
/// found in the folded text point to the wrong place in `source`, or even past
/// its end.
fn find_identifier_position(source: &str, name: &str) -> (usize, usize) {
    let haystack = source.as_bytes();
    let needle = name.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let mut start = 0;
    while start + needle.len() <= haystack.len() {
        let end = start + needle.len();
        if haystack[start..end].eq_ignore_ascii_case(needle) {
            let before_ok = start == 0 || !is_word_byte(haystack[start - 1]);
            let after_ok = end == haystack.len() || !is_word_byte(haystack[end]);
            if before_ok && after_ok {
                // `start` lies on a char boundary: the first matched byte equals either
                // an ASCII byte or the UTF-8 lead byte that begins `name`.
                return offset_to_line_col(source, start);
            }
        }
        start += 1;
    }
    (1, 1)
}

/// Converts a byte `offset` into `source` into a 1-based `(line, column)` pair.
/// The column counts bytes from the start of the line.
///
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    (line, offset - line_start + 1)
}