1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, SetExpr, Statement,
3 TableFactor, Value};
4
5pub struct IntegerDivision;
6
7impl Rule for IntegerDivision {
8 fn name(&self) -> &'static str {
9 "Ambiguous/IntegerDivision"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 for stmt in &ctx.statements {
19 collect_from_statement(stmt, ctx, &mut diags);
20 }
21 diags
22 }
23}
24
25fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
26 if let Statement::Query(query) = stmt {
27 collect_from_query(query, ctx, diags);
28 }
29}
30
31fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
32 if let Some(with) = &query.with {
33 for cte in &with.cte_tables {
34 collect_from_query(&cte.query, ctx, diags);
35 }
36 }
37 collect_from_set_expr(&query.body, ctx, diags);
38}
39
40fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
41 match expr {
42 SetExpr::Select(select) => {
43 collect_from_select(select, ctx, diags);
44 }
45 SetExpr::Query(inner) => {
46 collect_from_query(inner, ctx, diags);
47 }
48 SetExpr::SetOperation { left, right, .. } => {
49 collect_from_set_expr(left, ctx, diags);
50 collect_from_set_expr(right, ctx, diags);
51 }
52 _ => {}
53 }
54}
55
56fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
57 for item in &select.projection {
59 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
60 check_expr(e, ctx, diags);
61 }
62 }
63
64 for twj in &select.from {
66 collect_from_table_factor(&twj.relation, ctx, diags);
67 for join in &twj.joins {
68 collect_from_table_factor(&join.relation, ctx, diags);
69 if let sqlparser::ast::JoinOperator::Inner(sqlparser::ast::JoinConstraint::On(e))
70 | sqlparser::ast::JoinOperator::LeftOuter(sqlparser::ast::JoinConstraint::On(e))
71 | sqlparser::ast::JoinOperator::RightOuter(sqlparser::ast::JoinConstraint::On(e))
72 | sqlparser::ast::JoinOperator::FullOuter(sqlparser::ast::JoinConstraint::On(e)) =
73 &join.join_operator
74 {
75 check_expr(e, ctx, diags);
76 }
77 }
78 }
79
80 if let Some(selection) = &select.selection {
82 check_expr(selection, ctx, diags);
83 }
84
85 if let Some(having) = &select.having {
87 check_expr(having, ctx, diags);
88 }
89}
90
91fn collect_from_table_factor(
92 factor: &TableFactor,
93 ctx: &FileContext,
94 diags: &mut Vec<Diagnostic>,
95) {
96 if let TableFactor::Derived { subquery, .. } = factor {
97 collect_from_query(subquery, ctx, diags);
98 }
99}
100
101fn is_integer_literal(expr: &Expr) -> Option<String> {
103 if let Expr::Value(Value::Number(s, _)) = expr {
104 if !s.contains('.') {
105 return Some(s.clone());
106 }
107 }
108 None
109}
110
111fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
112 match expr {
113 Expr::BinaryOp { left, op, right } => {
114 check_expr(left, ctx, diags);
116 check_expr(right, ctx, diags);
117
118 if matches!(op, BinaryOperator::Divide) {
119 if let (Some(lval), Some(rval)) =
120 (is_integer_literal(left), is_integer_literal(right))
121 {
122 let (line, col) = find_integer_division_position(&ctx.source, &lval, &rval);
123 diags.push(Diagnostic {
124 rule: "Ambiguous/IntegerDivision",
125 message: format!(
126 "Integer division {} / {} truncates towards zero \
127 — use CAST(expr AS FLOAT) or add .0 to a literal for decimal division",
128 lval, rval
129 ),
130 line,
131 col,
132 });
133 }
134 }
135 }
136 Expr::UnaryOp { expr: inner, .. } => {
137 check_expr(inner, ctx, diags);
138 }
139 Expr::Nested(inner) => {
140 check_expr(inner, ctx, diags);
141 }
142 Expr::Case {
143 operand,
144 conditions,
145 results,
146 else_result,
147 } => {
148 if let Some(op) = operand {
149 check_expr(op, ctx, diags);
150 }
151 for cond in conditions {
152 check_expr(cond, ctx, diags);
153 }
154 for result in results {
155 check_expr(result, ctx, diags);
156 }
157 if let Some(else_e) = else_result {
158 check_expr(else_e, ctx, diags);
159 }
160 }
161 Expr::InList { expr: inner, list, .. } => {
162 check_expr(inner, ctx, diags);
163 for e in list {
164 check_expr(e, ctx, diags);
165 }
166 }
167 Expr::Between {
168 expr: inner,
169 low,
170 high,
171 ..
172 } => {
173 check_expr(inner, ctx, diags);
174 check_expr(low, ctx, diags);
175 check_expr(high, ctx, diags);
176 }
177 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
178 check_expr(inner, ctx, diags);
179 }
180 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
181 collect_from_query(q, ctx, diags);
182 }
183 _ => {}
184 }
185}
186
/// Finds the `/` of the first `lval / rval` occurrence in `source`.
///
/// Returns its 1-based `(line, col)` position, or `(1, 1)` when there is no
/// match.
///
/// The search is textual but token-aware. `lval` must not be preceded, and
/// `rval` must not be followed, by an ASCII letter, digit, `_` or `.`. This
/// keeps the `1` inside `11`, `0.1` or `col1` from being mistaken for the
/// literal `1`. Any ASCII whitespace, including line breaks, may surround the
/// slash.
///
/// Limitations:
/// - Comments and string literals are not skipped.
/// - When the same `lval / rval` pair occurs several times, the first
///   occurrence is always reported.
fn find_integer_division_position(source: &str, lval: &str, rval: &str) -> (usize, usize) {
    let bytes = source.as_bytes();
    let (lval, rval) = (lval.as_bytes(), rval.as_bytes());

    // A byte that would make an adjacent literal part of a larger token
    // (another number, an identifier, or a decimal fraction).
    let is_token_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'.';
    // Index of the first non-whitespace byte at or after `pos` (or `len`).
    let skip_ws = |mut pos: usize| {
        while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        pos
    };

    for start in 0..bytes.len() {
        if !bytes[start..].starts_with(lval) {
            continue;
        }
        if start > 0 && is_token_byte(bytes[start - 1]) {
            continue;
        }
        let slash = skip_ws(start + lval.len());
        if bytes.get(slash) != Some(&b'/') {
            continue;
        }
        let r_start = skip_ws(slash + 1);
        if !bytes[r_start..].starts_with(rval) {
            continue;
        }
        let r_end = r_start + rval.len();
        if r_end >= bytes.len() || !is_token_byte(bytes[r_end]) {
            return offset_to_line_col(source, slash);
        }
    }

    (1, 1)
}

/// Converts a byte `offset` into `source` to a 1-based `(line, col)` pair.
///
/// Lines are delimited by `\n`, and the column counts bytes from the start of
/// the line. Panics if `offset` is past the end of `source` or not on a UTF-8
/// character boundary. The caller in this file always passes the offset of an
/// ASCII `/`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |nl| nl + 1);
    (line, offset - line_start + 1)
}