sqrust_rules/ambiguous/
or_in_join_condition.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Join, JoinConstraint, JoinOperator, Query, SetExpr,
3 Statement, TableFactor};
4
/// Lint rule that reports `JOIN ... ON` conditions containing an `OR`.
///
/// The type carries no state; it exists to hold the `Rule` implementation.
/// The derives make it printable, copyable, comparable and
/// default-constructible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OrInJoinCondition;
6
7impl Rule for OrInJoinCondition {
8 fn name(&self) -> &'static str {
9 "Ambiguous/OrInJoinCondition"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 for stmt in &ctx.statements {
19 if let Statement::Query(query) = stmt {
20 check_query(query, &ctx.source, &mut diags);
21 }
22 }
23 diags
24 }
25}
26
27fn check_query(query: &Query, source: &str, diags: &mut Vec<Diagnostic>) {
28 if let Some(with) = &query.with {
30 for cte in &with.cte_tables {
31 check_query(&cte.query, source, diags);
32 }
33 }
34 check_set_expr(&query.body, source, diags);
35}
36
37fn check_set_expr(expr: &SetExpr, source: &str, diags: &mut Vec<Diagnostic>) {
38 match expr {
39 SetExpr::Select(select) => {
40 for twj in &select.from {
41 for join in &twj.joins {
43 check_join(join, source, diags);
44 }
45 recurse_table_factor(&twj.relation, source, diags);
47 for join in &twj.joins {
48 recurse_table_factor(&join.relation, source, diags);
49 }
50 }
51 }
52 SetExpr::SetOperation { left, right, .. } => {
53 check_set_expr(left, source, diags);
54 check_set_expr(right, source, diags);
55 }
56 SetExpr::Query(inner) => {
57 check_query(inner, source, diags);
58 }
59 _ => {}
60 }
61}
62
63fn on_expr(join: &Join) -> Option<&Expr> {
65 match &join.join_operator {
66 JoinOperator::Inner(JoinConstraint::On(e))
67 | JoinOperator::LeftOuter(JoinConstraint::On(e))
68 | JoinOperator::RightOuter(JoinConstraint::On(e))
69 | JoinOperator::FullOuter(JoinConstraint::On(e)) => Some(e),
70 _ => None,
71 }
72}
73
74fn check_join(join: &Join, source: &str, diags: &mut Vec<Diagnostic>) {
76 if let Some(expr) = on_expr(join) {
77 if has_or(expr) {
78 let (line, col) = find_or_position(source);
79 diags.push(Diagnostic {
80 rule: "Ambiguous/OrInJoinCondition",
81 message: "OR condition in JOIN ON clause; this may produce unintended cross-join-like results"
82 .to_string(),
83 line,
84 col,
85 });
86 }
87 }
88}
89
90fn has_or(expr: &Expr) -> bool {
92 match expr {
93 Expr::BinaryOp {
94 op: BinaryOperator::Or,
95 ..
96 } => true,
97 Expr::BinaryOp { left, right, .. } => has_or(left) || has_or(right),
98 Expr::Nested(e) => has_or(e),
99 Expr::UnaryOp { expr: e, .. } => has_or(e),
100 _ => false,
101 }
102}
103
104fn recurse_table_factor(tf: &TableFactor, source: &str, diags: &mut Vec<Diagnostic>) {
106 if let TableFactor::Derived { subquery, .. } = tf {
107 check_query(subquery, source, diags);
108 }
109}
110
/// Finds the first standalone `OR` keyword (ASCII case-insensitive) in
/// `source` and returns its 1-based `(line, column)`.
///
/// A match is standalone when it is not directly preceded or followed by an
/// ASCII letter, digit or underscore. Text inside `--` line comments,
/// `/* ... */` block comments and quoted regions (`'...'`, `"..."`,
/// `` `...` ``) is skipped so that an "or" in a comment, string literal or
/// quoted identifier is never reported. Inside a quoted region a doubled
/// quote character is treated as an escaped quote; backslash escapes are not
/// interpreted.
///
/// Returns `(1, 1)` when no keyword is found.
fn find_or_position(source: &str) -> (usize, usize) {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let mut i = 0;
    while i < len {
        match bytes[i] {
            // Quoted string or identifier: jump past the closing quote.
            quote @ (b'\'' | b'"' | b'`') => {
                i += 1;
                while i < len {
                    if bytes[i] == quote {
                        if bytes.get(i + 1) == Some(&quote) {
                            // Doubled quote is an escaped quote, still inside.
                            i += 2;
                            continue;
                        }
                        break;
                    }
                    i += 1;
                }
                // Step over the closing quote (harmlessly past the end if unterminated).
                i += 1;
            }
            // Line comment: skip up to (not including) the newline.
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
            }
            // Block comment: skip through the closing `*/`.
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i += 2;
                while i < len && !(bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/')) {
                    i += 1;
                }
                i += 2;
            }
            b'o' | b'O' => {
                let end = i + 2;
                let r_follows = matches!(bytes.get(i + 1).copied(), Some(b'r' | b'R'));
                let before_ok = i == 0 || !is_word(bytes[i - 1]);
                let after_ok = end >= len || !is_word(bytes[end]);
                if r_follows && before_ok && after_ok {
                    return offset_to_line_col(source, i);
                }
                i += 1;
            }
            _ => i += 1,
        }
    }

    (1, 1)
}

/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. `offset` must
/// lie on a character boundary within `source`, otherwise slicing panics.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    match before.rfind('\n') {
        // `offset - newline` equals the distance past the newline, plus one.
        Some(newline) => (before.matches('\n').count() + 1, offset - newline),
        None => (1, offset + 1),
    }
}