Skip to main content

sqrust_rules/ambiguous/integer_division.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, SetExpr, Statement,
3    TableFactor, Value};
4
5pub struct IntegerDivision;
6
7impl Rule for IntegerDivision {
8    fn name(&self) -> &'static str {
9        "Ambiguous/IntegerDivision"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16
17        let mut diags = Vec::new();
18        for stmt in &ctx.statements {
19            collect_from_statement(stmt, ctx, &mut diags);
20        }
21        diags
22    }
23}
24
25fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
26    if let Statement::Query(query) = stmt {
27        collect_from_query(query, ctx, diags);
28    }
29}
30
31fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
32    if let Some(with) = &query.with {
33        for cte in &with.cte_tables {
34            collect_from_query(&cte.query, ctx, diags);
35        }
36    }
37    collect_from_set_expr(&query.body, ctx, diags);
38}
39
40fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
41    match expr {
42        SetExpr::Select(select) => {
43            collect_from_select(select, ctx, diags);
44        }
45        SetExpr::Query(inner) => {
46            collect_from_query(inner, ctx, diags);
47        }
48        SetExpr::SetOperation { left, right, .. } => {
49            collect_from_set_expr(left, ctx, diags);
50            collect_from_set_expr(right, ctx, diags);
51        }
52        _ => {}
53    }
54}
55
56fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
57    // Check SELECT projection.
58    for item in &select.projection {
59        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
60            check_expr(e, ctx, diags);
61        }
62    }
63
64    // Check FROM subqueries.
65    for twj in &select.from {
66        collect_from_table_factor(&twj.relation, ctx, diags);
67        for join in &twj.joins {
68            collect_from_table_factor(&join.relation, ctx, diags);
69            if let sqlparser::ast::JoinOperator::Inner(sqlparser::ast::JoinConstraint::On(e))
70            | sqlparser::ast::JoinOperator::LeftOuter(sqlparser::ast::JoinConstraint::On(e))
71            | sqlparser::ast::JoinOperator::RightOuter(sqlparser::ast::JoinConstraint::On(e))
72            | sqlparser::ast::JoinOperator::FullOuter(sqlparser::ast::JoinConstraint::On(e)) =
73                &join.join_operator
74            {
75                check_expr(e, ctx, diags);
76            }
77        }
78    }
79
80    // Check WHERE.
81    if let Some(selection) = &select.selection {
82        check_expr(selection, ctx, diags);
83    }
84
85    // Check HAVING.
86    if let Some(having) = &select.having {
87        check_expr(having, ctx, diags);
88    }
89}
90
91fn collect_from_table_factor(
92    factor: &TableFactor,
93    ctx: &FileContext,
94    diags: &mut Vec<Diagnostic>,
95) {
96    if let TableFactor::Derived { subquery, .. } = factor {
97        collect_from_query(subquery, ctx, diags);
98    }
99}
100
101/// Returns `true` when `expr` is an integer literal (no decimal point).
102fn is_integer_literal(expr: &Expr) -> Option<String> {
103    if let Expr::Value(Value::Number(s, _)) = expr {
104        if !s.contains('.') {
105            return Some(s.clone());
106        }
107    }
108    None
109}
110
111fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
112    match expr {
113        Expr::BinaryOp { left, op, right } => {
114            // Recurse into children first so nested divisions are caught.
115            check_expr(left, ctx, diags);
116            check_expr(right, ctx, diags);
117
118            if matches!(op, BinaryOperator::Divide) {
119                if let (Some(lval), Some(rval)) =
120                    (is_integer_literal(left), is_integer_literal(right))
121                {
122                    let (line, col) = find_integer_division_position(&ctx.source, &lval, &rval);
123                    diags.push(Diagnostic {
124                        rule: "Ambiguous/IntegerDivision",
125                        message: format!(
126                            "Integer division {} / {} truncates towards zero \
127                             — use CAST(expr AS FLOAT) or add .0 to a literal for decimal division",
128                            lval, rval
129                        ),
130                        line,
131                        col,
132                    });
133                }
134            }
135        }
136        Expr::UnaryOp { expr: inner, .. } => {
137            check_expr(inner, ctx, diags);
138        }
139        Expr::Nested(inner) => {
140            check_expr(inner, ctx, diags);
141        }
142        Expr::Case {
143            operand,
144            conditions,
145            results,
146            else_result,
147        } => {
148            if let Some(op) = operand {
149                check_expr(op, ctx, diags);
150            }
151            for cond in conditions {
152                check_expr(cond, ctx, diags);
153            }
154            for result in results {
155                check_expr(result, ctx, diags);
156            }
157            if let Some(else_e) = else_result {
158                check_expr(else_e, ctx, diags);
159            }
160        }
161        Expr::InList { expr: inner, list, .. } => {
162            check_expr(inner, ctx, diags);
163            for e in list {
164                check_expr(e, ctx, diags);
165            }
166        }
167        Expr::Between {
168            expr: inner,
169            low,
170            high,
171            ..
172        } => {
173            check_expr(inner, ctx, diags);
174            check_expr(low, ctx, diags);
175            check_expr(high, ctx, diags);
176        }
177        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
178            check_expr(inner, ctx, diags);
179        }
180        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
181            collect_from_query(q, ctx, diags);
182        }
183        _ => {}
184    }
185}
186
187/// Finds the position of an integer/integer division pattern (e.g. `1/2`) in source.
188/// Searches for `lval` followed by optional spaces, `/`, optional spaces, `rval`.
189/// Falls back to (1, 1) if not found.
190fn find_integer_division_position(source: &str, lval: &str, rval: &str) -> (usize, usize) {
191    let bytes = source.as_bytes();
192    let len = bytes.len();
193    let lval_bytes = lval.as_bytes();
194    let rval_bytes = rval.as_bytes();
195    let llen = lval_bytes.len();
196    let rlen = rval_bytes.len();
197
198    let mut i = 0;
199    while i + llen <= len {
200        // Match lval at position i — check word boundary after
201        if &bytes[i..i + llen] == lval_bytes {
202            let after_l = i + llen;
203            // Skip whitespace
204            let mut j = after_l;
205            while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
206                j += 1;
207            }
208            if j < len && bytes[j] == b'/' {
209                let slash_pos = j;
210                j += 1;
211                // Skip whitespace after /
212                while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
213                    j += 1;
214                }
215                // Match rval
216                if j + rlen <= len && &bytes[j..j + rlen] == rval_bytes {
217                    // Ensure the next char after rval is not a digit or dot (not part of a larger number)
218                    let after_r = j + rlen;
219                    let rval_ends =
220                        after_r >= len || (!bytes[after_r].is_ascii_digit() && bytes[after_r] != b'.');
221                    if rval_ends {
222                        return offset_to_line_col(source, slash_pos);
223                    }
224                }
225            }
226        }
227        i += 1;
228    }
229
230    (1, 1)
231}
232
/// Converts a byte offset in `source` into a 1-indexed `(line, col)` pair.
///
/// `col` counts bytes from the start of the line, not characters.
/// `offset` must be a UTF-8 char boundary no greater than `source.len()`;
/// otherwise slicing panics.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    // '\n' is a single ASCII byte, so counting bytes equals counting chars here.
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line, offset - line_start + 1)
}