Skip to main content

sqrust_rules/ambiguous/
division_by_zero.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, SetExpr, Statement,
3    TableFactor, Value};
4
5pub struct DivisionByZero;
6
7impl Rule for DivisionByZero {
8    fn name(&self) -> &'static str {
9        "Ambiguous/DivisionByZero"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16
17        let mut diags = Vec::new();
18        for stmt in &ctx.statements {
19            collect_from_statement(stmt, ctx, &mut diags);
20        }
21        diags
22    }
23}
24
25fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
26    if let Statement::Query(query) = stmt {
27        collect_from_query(query, ctx, diags);
28    }
29}
30
31fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
32    if let Some(with) = &query.with {
33        for cte in &with.cte_tables {
34            collect_from_query(&cte.query, ctx, diags);
35        }
36    }
37    collect_from_set_expr(&query.body, ctx, diags);
38}
39
40fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
41    match expr {
42        SetExpr::Select(select) => {
43            collect_from_select(select, ctx, diags);
44        }
45        SetExpr::Query(inner) => {
46            collect_from_query(inner, ctx, diags);
47        }
48        SetExpr::SetOperation { left, right, .. } => {
49            collect_from_set_expr(left, ctx, diags);
50            collect_from_set_expr(right, ctx, diags);
51        }
52        _ => {}
53    }
54}
55
56fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
57    // Check SELECT projection.
58    for item in &select.projection {
59        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
60            check_expr(e, ctx, diags);
61        }
62    }
63
64    // Check FROM subqueries.
65    for twj in &select.from {
66        collect_from_table_factor(&twj.relation, ctx, diags);
67        for join in &twj.joins {
68            collect_from_table_factor(&join.relation, ctx, diags);
69            // Check any ON condition in the join.
70            if let sqlparser::ast::JoinOperator::Inner(sqlparser::ast::JoinConstraint::On(e))
71            | sqlparser::ast::JoinOperator::LeftOuter(sqlparser::ast::JoinConstraint::On(e))
72            | sqlparser::ast::JoinOperator::RightOuter(sqlparser::ast::JoinConstraint::On(e))
73            | sqlparser::ast::JoinOperator::FullOuter(sqlparser::ast::JoinConstraint::On(e)) =
74                &join.join_operator
75            {
76                check_expr(e, ctx, diags);
77            }
78        }
79    }
80
81    // Check WHERE.
82    if let Some(selection) = &select.selection {
83        check_expr(selection, ctx, diags);
84    }
85
86    // Check HAVING.
87    if let Some(having) = &select.having {
88        check_expr(having, ctx, diags);
89    }
90}
91
92fn collect_from_table_factor(
93    factor: &TableFactor,
94    ctx: &FileContext,
95    diags: &mut Vec<Diagnostic>,
96) {
97    if let TableFactor::Derived { subquery, .. } = factor {
98        collect_from_query(subquery, ctx, diags);
99    }
100}
101
102/// Returns `true` when `expr` is a numeric literal whose value is zero.
103fn is_zero_literal(expr: &Expr) -> bool {
104    if let Expr::Value(Value::Number(s, _)) = expr {
105        // Parse as f64 to handle 0, 0.0, 0.00, etc.
106        s.parse::<f64>().map(|v| v == 0.0).unwrap_or(false)
107    } else {
108        false
109    }
110}
111
112fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
113    match expr {
114        Expr::BinaryOp { left, op, right } => {
115            // Recurse into children first so nested / 0 inside `a / 2 / 0` is caught.
116            check_expr(left, ctx, diags);
117            check_expr(right, ctx, diags);
118
119            if matches!(op, BinaryOperator::Divide) && is_zero_literal(right) {
120                let (line, col) = find_division_position(&ctx.source);
121                diags.push(Diagnostic {
122                    rule: "Ambiguous/DivisionByZero",
123                    message: "Division by zero literal; this will cause an error or return NULL"
124                        .to_string(),
125                    line,
126                    col,
127                });
128            }
129        }
130        Expr::UnaryOp { expr: inner, .. } => {
131            check_expr(inner, ctx, diags);
132        }
133        Expr::Nested(inner) => {
134            check_expr(inner, ctx, diags);
135        }
136        Expr::Case {
137            operand,
138            conditions,
139            results,
140            else_result,
141        } => {
142            if let Some(op) = operand {
143                check_expr(op, ctx, diags);
144            }
145            for cond in conditions {
146                check_expr(cond, ctx, diags);
147            }
148            for result in results {
149                check_expr(result, ctx, diags);
150            }
151            if let Some(else_e) = else_result {
152                check_expr(else_e, ctx, diags);
153            }
154        }
155        Expr::InList { expr: inner, list, .. } => {
156            check_expr(inner, ctx, diags);
157            for e in list {
158                check_expr(e, ctx, diags);
159            }
160        }
161        Expr::Between {
162            expr: inner,
163            low,
164            high,
165            ..
166        } => {
167            check_expr(inner, ctx, diags);
168            check_expr(low, ctx, diags);
169            check_expr(high, ctx, diags);
170        }
171        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
172            check_expr(inner, ctx, diags);
173        }
174        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
175            collect_from_query(q, ctx, diags);
176        }
177        _ => {}
178    }
179}
180
181/// Finds the first `/ 0` or `/ 0.0` (etc.) pattern in `source` and returns
182/// its 1-indexed (line, col). Falls back to (1, 1) if not found.
183///
184/// Scans for a `/` character followed by optional whitespace followed by a
185/// zero-value numeric token (`0`, `0.0`, `0.00`, etc.).
186fn find_division_position(source: &str) -> (usize, usize) {
187    let bytes = source.as_bytes();
188    let len = bytes.len();
189    let mut i = 0;
190
191    while i < len {
192        if bytes[i] == b'/' {
193            // Skip whitespace after the slash.
194            let mut j = i + 1;
195            while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
196                j += 1;
197            }
198            // Check if what follows is a zero literal.
199            if j < len && bytes[j].is_ascii_digit() {
200                let start = j;
201                // Collect the numeric token.
202                while j < len && (bytes[j].is_ascii_digit() || bytes[j] == b'.') {
203                    j += 1;
204                }
205                let token = std::str::from_utf8(&bytes[start..j]).unwrap_or("");
206                if token.parse::<f64>().map(|v| v == 0.0).unwrap_or(false) {
207                    return offset_to_line_col(source, i);
208                }
209            }
210        }
211        i += 1;
212    }
213
214    (1, 1)
215}
216
/// Maps the byte `offset` into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` characters before `offset`.
/// The column counts bytes from the start of that line, so the first byte of
/// a line is column 1.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.matches('\n').count();
    (line, offset - line_start + 1)
}