// sqrust_rules/ambiguous/date_arithmetic.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct DateArithmetic;
4
5impl Rule for DateArithmetic {
6    fn name(&self) -> &'static str {
7        "Ambiguous/DateArithmetic"
8    }
9
10    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11        if !ctx.parse_errors.is_empty() {
12            return Vec::new();
13        }
14
15        let source = &ctx.source;
16        find_date_arithmetic_violations(source, ctx)
17    }
18}
19
/// Hint words that mark an identifier as a date/time column when the identifier is
/// exactly the word, or starts with the word followed by `_`
/// (e.g. `date_col`, `timestamp_field`, `created_at`, `updated_on`, `ts_start`).
const DATE_PREFIXES: &[&str] = &[
    "date", "time", "timestamp", "ts", "created", "updated", "modified",
];

/// Hint words that mark an identifier as a date/time column when the identifier ends
/// with `_` followed by the word (e.g. `order_date`, `event_time`, `created_at`,
/// `updated_on`).
const DATE_SUFFIXES: &[&str] = &[
    "date", "time", "timestamp", "ts", "at", "on", "created", "updated", "modified",
];

/// Heuristically decides whether `token` names a date/time column.
///
/// Matching is ASCII case-insensitive. A token qualifies when it either
/// - equals a [`DATE_PREFIXES`] word, or starts with that word plus `_`
///   (`date`, `date_col`, `ts_start`); or
/// - ends with `_` plus a [`DATE_SUFFIXES`] word (`created_at`, `order_date`).
///
/// These do not qualify:
/// - a hint word that appears only in the middle (`non_date_col`);
/// - a hint word without the underscore separator (`dated`, `format`);
/// - the empty string.
fn is_date_like_identifier(token: &str) -> bool {
    let lower = token.to_ascii_lowercase();

    let starts_with_hint = DATE_PREFIXES.iter().copied().any(|hint| {
        matches!(
            lower.strip_prefix(hint),
            Some(rest) if rest.is_empty() || rest.starts_with('_')
        )
    });

    starts_with_hint
        || DATE_SUFFIXES.iter().copied().any(|hint| {
            matches!(lower.strip_suffix(hint), Some(head) if head.ends_with('_'))
        })
}
60
/// Returns `true` when `token` is one or more ASCII digits, i.e. an unsigned integer
/// literal with no sign, decimal point or exponent.
fn is_integer_token(token: &str) -> bool {
    match token.as_bytes() {
        [] => false,
        digits => digits.iter().all(u8::is_ascii_digit),
    }
}
65
/// Returns `true` if `b` can appear in an unquoted identifier or integer token:
/// an ASCII letter, an ASCII digit, or `_`.
fn is_ident_char(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
70
71/// Scans `source` skipping string literals and comments, finding tokens around
72/// `+` and `-` operators. Flags patterns where one side is a date-hint identifier
73/// and the other side is a plain integer literal.
74fn find_date_arithmetic_violations(source: &str, _ctx: &FileContext) -> Vec<Diagnostic> {
75    let mut diags = Vec::new();
76    let bytes = source.as_bytes();
77    let len = bytes.len();
78    let mut i = 0;
79
80    while i < len {
81        // Skip string literals delimited by single quotes.
82        if bytes[i] == b'\'' {
83            i += 1;
84            while i < len && bytes[i] != b'\'' {
85                if bytes[i] == b'\\' {
86                    i += 1;
87                }
88                i += 1;
89            }
90            i += 1; // skip closing quote
91            continue;
92        }
93
94        // Skip line comments (-- to end of line).
95        if i + 1 < len && bytes[i] == b'-' && bytes[i + 1] == b'-' {
96            while i < len && bytes[i] != b'\n' {
97                i += 1;
98            }
99            continue;
100        }
101
102        // Skip block comments (/* ... */).
103        if i + 1 < len && bytes[i] == b'/' && bytes[i + 1] == b'*' {
104            i += 2;
105            while i + 1 < len && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
106                i += 1;
107            }
108            i += 2; // skip closing */
109            continue;
110        }
111
112        // Check for + or - operator.
113        if bytes[i] == b'+' || bytes[i] == b'-' {
114            let op_pos = i;
115
116            // Look backwards for the left token (skip whitespace).
117            let left_end = scan_back_skip_whitespace(bytes, op_pos);
118            let left_token = extract_token_backwards(bytes, left_end);
119
120            // Look forwards for the right token (skip whitespace).
121            let right_start = scan_forward_skip_whitespace(bytes, op_pos + 1);
122            let right_token = extract_token_forwards(bytes, right_start);
123
124            // Flag if (date_like_identifier op integer) or (integer op date_like_identifier).
125            let should_flag = (!left_token.is_empty()
126                && !right_token.is_empty()
127                && is_date_like_identifier(&left_token)
128                && is_integer_token(&right_token))
129                || (!left_token.is_empty()
130                    && !right_token.is_empty()
131                    && is_integer_token(&left_token)
132                    && is_date_like_identifier(&right_token));
133
134            if should_flag {
135                let (line, col) = offset_to_line_col(source, op_pos);
136                diags.push(Diagnostic {
137                    rule: "Ambiguous/DateArithmetic",
138                    message: "Date arithmetic with integer offset is database-specific \
139                               — use INTERVAL '1' DAY or dialect-specific functions for portability"
140                        .to_string(),
141                    line,
142                    col,
143                });
144            }
145        }
146
147        i += 1;
148    }
149
150    diags
151}
152
/// Returns the index of the last non-whitespace byte strictly before `pos`.
///
/// Only ASCII space, tab, `\n` and `\r` count as whitespace. If `pos` is 0, or every
/// byte before `pos` is whitespace, the result is 0. Callers must therefore inspect
/// the byte at the returned index rather than assume it is non-whitespace.
///
/// Panics if `pos > bytes.len()`.
fn scan_back_skip_whitespace(bytes: &[u8], pos: usize) -> usize {
    if pos == 0 {
        return 0;
    }
    bytes[..pos]
        .iter()
        .rposition(|&b| !matches!(b, b' ' | b'\t' | b'\n' | b'\r'))
        .unwrap_or(0)
}
165
/// Returns the index of the first byte at or after `pos` that is not ASCII space, tab,
/// `\n` or `\r`. Returns `bytes.len()` when only whitespace remains, and returns `pos`
/// unchanged when it is already past the end.
fn scan_forward_skip_whitespace(bytes: &[u8], pos: usize) -> usize {
    let skipped = bytes.get(pos..).map_or(0, |rest| {
        rest.iter()
            .take_while(|&&b| matches!(b, b' ' | b'\t' | b'\n' | b'\r'))
            .count()
    });
    pos + skipped
}
176
177/// Extracts the identifier or integer token ending at byte index `end_pos` (inclusive).
178/// Works backwards from `end_pos` to find the token start.
179fn extract_token_backwards(bytes: &[u8], end_pos: usize) -> String {
180    if end_pos >= bytes.len() || !is_ident_char(bytes[end_pos]) {
181        return String::new();
182    }
183    let mut start = end_pos;
184    while start > 0 && is_ident_char(bytes[start - 1]) {
185        start -= 1;
186    }
187    String::from_utf8_lossy(&bytes[start..=end_pos]).into_owned()
188}
189
190/// Extracts the identifier or integer token starting at byte index `start_pos`.
191fn extract_token_forwards(bytes: &[u8], start_pos: usize) -> String {
192    if start_pos >= bytes.len() || !is_ident_char(bytes[start_pos]) {
193        return String::new();
194    }
195    let mut end = start_pos;
196    while end < bytes.len() && is_ident_char(bytes[end]) {
197        end += 1;
198    }
199    String::from_utf8_lossy(&bytes[start_pos..end]).into_owned()
200}
201
/// Maps byte `offset` in `source` to a 1-indexed `(line, column)` pair.
///
/// Lines are delimited by `\n`. The column is measured in bytes from the start of the
/// line, so a multi-byte UTF-8 character earlier on the same line adds more than one.
///
/// Panics if `offset` is past the end of `source` or not on a `char` boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        None => (1, offset + 1),
        Some(last_newline) => {
            let line = prefix.bytes().filter(|&b| b == b'\n').count() + 1;
            (line, offset - last_newline)
        }
    }
}