1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, SetExpr, Statement,
3 TableFactor, Value};
4
5pub struct DivisionByZero;
6
7impl Rule for DivisionByZero {
8 fn name(&self) -> &'static str {
9 "Ambiguous/DivisionByZero"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 for stmt in &ctx.statements {
19 collect_from_statement(stmt, ctx, &mut diags);
20 }
21 diags
22 }
23}
24
25fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
26 if let Statement::Query(query) = stmt {
27 collect_from_query(query, ctx, diags);
28 }
29}
30
31fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
32 if let Some(with) = &query.with {
33 for cte in &with.cte_tables {
34 collect_from_query(&cte.query, ctx, diags);
35 }
36 }
37 collect_from_set_expr(&query.body, ctx, diags);
38}
39
40fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
41 match expr {
42 SetExpr::Select(select) => {
43 collect_from_select(select, ctx, diags);
44 }
45 SetExpr::Query(inner) => {
46 collect_from_query(inner, ctx, diags);
47 }
48 SetExpr::SetOperation { left, right, .. } => {
49 collect_from_set_expr(left, ctx, diags);
50 collect_from_set_expr(right, ctx, diags);
51 }
52 _ => {}
53 }
54}
55
56fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
57 for item in &select.projection {
59 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
60 check_expr(e, ctx, diags);
61 }
62 }
63
64 for twj in &select.from {
66 collect_from_table_factor(&twj.relation, ctx, diags);
67 for join in &twj.joins {
68 collect_from_table_factor(&join.relation, ctx, diags);
69 if let sqlparser::ast::JoinOperator::Inner(sqlparser::ast::JoinConstraint::On(e))
71 | sqlparser::ast::JoinOperator::LeftOuter(sqlparser::ast::JoinConstraint::On(e))
72 | sqlparser::ast::JoinOperator::RightOuter(sqlparser::ast::JoinConstraint::On(e))
73 | sqlparser::ast::JoinOperator::FullOuter(sqlparser::ast::JoinConstraint::On(e)) =
74 &join.join_operator
75 {
76 check_expr(e, ctx, diags);
77 }
78 }
79 }
80
81 if let Some(selection) = &select.selection {
83 check_expr(selection, ctx, diags);
84 }
85
86 if let Some(having) = &select.having {
88 check_expr(having, ctx, diags);
89 }
90}
91
92fn collect_from_table_factor(
93 factor: &TableFactor,
94 ctx: &FileContext,
95 diags: &mut Vec<Diagnostic>,
96) {
97 if let TableFactor::Derived { subquery, .. } = factor {
98 collect_from_query(subquery, ctx, diags);
99 }
100}
101
102fn is_zero_literal(expr: &Expr) -> bool {
104 if let Expr::Value(Value::Number(s, _)) = expr {
105 s.parse::<f64>().map(|v| v == 0.0).unwrap_or(false)
107 } else {
108 false
109 }
110}
111
112fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
113 match expr {
114 Expr::BinaryOp { left, op, right } => {
115 check_expr(left, ctx, diags);
117 check_expr(right, ctx, diags);
118
119 if matches!(op, BinaryOperator::Divide) && is_zero_literal(right) {
120 let (line, col) = find_division_position(&ctx.source);
121 diags.push(Diagnostic {
122 rule: "Ambiguous/DivisionByZero",
123 message: "Division by zero literal; this will cause an error or return NULL"
124 .to_string(),
125 line,
126 col,
127 });
128 }
129 }
130 Expr::UnaryOp { expr: inner, .. } => {
131 check_expr(inner, ctx, diags);
132 }
133 Expr::Nested(inner) => {
134 check_expr(inner, ctx, diags);
135 }
136 Expr::Case {
137 operand,
138 conditions,
139 results,
140 else_result,
141 } => {
142 if let Some(op) = operand {
143 check_expr(op, ctx, diags);
144 }
145 for cond in conditions {
146 check_expr(cond, ctx, diags);
147 }
148 for result in results {
149 check_expr(result, ctx, diags);
150 }
151 if let Some(else_e) = else_result {
152 check_expr(else_e, ctx, diags);
153 }
154 }
155 Expr::InList { expr: inner, list, .. } => {
156 check_expr(inner, ctx, diags);
157 for e in list {
158 check_expr(e, ctx, diags);
159 }
160 }
161 Expr::Between {
162 expr: inner,
163 low,
164 high,
165 ..
166 } => {
167 check_expr(inner, ctx, diags);
168 check_expr(low, ctx, diags);
169 check_expr(high, ctx, diags);
170 }
171 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
172 check_expr(inner, ctx, diags);
173 }
174 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
175 collect_from_query(q, ctx, diags);
176 }
177 _ => {}
178 }
179}
180
181fn find_division_position(source: &str) -> (usize, usize) {
187 let bytes = source.as_bytes();
188 let len = bytes.len();
189 let mut i = 0;
190
191 while i < len {
192 if bytes[i] == b'/' {
193 let mut j = i + 1;
195 while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
196 j += 1;
197 }
198 if j < len && bytes[j].is_ascii_digit() {
200 let start = j;
201 while j < len && (bytes[j].is_ascii_digit() || bytes[j] == b'.') {
203 j += 1;
204 }
205 let token = std::str::from_utf8(&bytes[start..j]).unwrap_or("");
206 if token.parse::<f64>().map(|v| v == 0.0).unwrap_or(false) {
207 return offset_to_line_col(source, i);
208 }
209 }
210 }
211 i += 1;
212 }
213
214 (1, 1)
215}
216
/// Converts a byte `offset` into `source` to a 1-based `(line, col)` pair.
///
/// The column is counted in bytes from the start of the line. Panics if
/// `offset` is past the end of `source` or not on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = offset - line_start + 1;
    (line, col)
}