// sqrust_rules/ambiguous/nulls_ordering.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct NullsOrdering;
4
5impl Rule for NullsOrdering {
6    fn name(&self) -> &'static str {
7        "Ambiguous/NullsOrdering"
8    }
9
10    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11        let source = &ctx.source;
12        let positions = find_order_by_positions(source);
13        let mut diags = Vec::new();
14
15        for order_by_offset in positions {
16            // Scan forward from the ORDER BY position to the next `;` or end of source.
17            let region_end = source[order_by_offset..]
18                .find(';')
19                .map(|rel| order_by_offset + rel)
20                .unwrap_or(source.len());
21            let region = &source[order_by_offset..region_end];
22
23            // Check whether `NULLS` appears in this region (word boundary, case-insensitive).
24            if !contains_nulls_keyword(region) {
25                let (line, col) = offset_to_line_col(source, order_by_offset);
26                diags.push(Diagnostic {
27                    rule: self.name(),
28                    message: "ORDER BY without NULLS FIRST/NULLS LAST is ambiguous; NULL sort order varies by database".to_string(),
29                    line,
30                    col,
31                });
32            }
33        }
34
35        diags
36    }
37}
38
/// Finds the byte offset of every `ORDER BY` in `source` that appears as code,
/// i.e. outside string literals, quoted identifiers and comments.
///
/// `ORDER` and `BY` are matched case-insensitively as whole words and may be
/// separated by any run of ASCII whitespace (spaces, tabs, newlines). Each
/// returned offset points at the `O` of `ORDER`.
fn find_order_by_positions(source: &str) -> Vec<usize> {
    let bytes = source.as_bytes();
    let mut positions = Vec::new();
    let mut i = 0;

    while i < bytes.len() {
        // Jump over literals and comments as whole units, so a quote inside a
        // comment (`-- don't`) cannot desynchronise the rest of the scan.
        if let Some(next) = skip_literal_or_comment(bytes, i) {
            i = next;
        } else if let Some(after_by) = match_order_by(bytes, i) {
            positions.push(i);
            i = after_by;
        } else {
            i += 1;
        }
    }

    positions
}

/// Returns true if `region` contains the keyword `NULLS` (case-insensitive,
/// whole word) outside string literals, quoted identifiers and comments.
fn contains_nulls_keyword(region: &str) -> bool {
    let bytes = region.as_bytes();
    let mut i = 0;

    while i < bytes.len() {
        if let Some(next) = skip_literal_or_comment(bytes, i) {
            i = next;
        } else if is_keyword_at(bytes, i, b"NULLS") {
            return true;
        } else {
            i += 1;
        }
    }

    false
}

/// If `ORDER <whitespace> BY` starts at `i`, returns the index just past `BY`.
fn match_order_by(bytes: &[u8], i: usize) -> Option<usize> {
    if !is_keyword_at(bytes, i, b"ORDER") {
        return None;
    }
    let gap_start = i + b"ORDER".len();
    let by_start = gap_start
        + bytes[gap_start..]
            .iter()
            .take_while(|b| b.is_ascii_whitespace())
            .count();
    // At least one whitespace byte is required between the two words.
    if by_start > gap_start && is_keyword_at(bytes, by_start, b"BY") {
        Some(by_start + b"BY".len())
    } else {
        None
    }
}

/// True if the uppercase ASCII keyword `kw` occurs at `i`, compared
/// case-insensitively, with no identifier byte immediately before or after it.
fn is_keyword_at(bytes: &[u8], i: usize, kw: &[u8]) -> bool {
    let end = i + kw.len();
    end <= bytes.len()
        && bytes[i..end].eq_ignore_ascii_case(kw)
        && (i == 0 || !is_ident_byte(bytes[i - 1]))
        && bytes.get(end).map_or(true, |&b| !is_ident_byte(b))
}

/// Bytes treated as part of an unquoted identifier for word-boundary checks.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// If a string literal, quoted identifier or comment begins at `i`, returns
/// the index just past its end (`bytes.len()` if it is unterminated);
/// otherwise returns `None`.
///
/// Recognised forms: `'...'`, `"..."` and `` `...` `` (a doubled quote is an
/// escaped quote), `-- ...` up to the newline, and non-nested `/* ... */`.
fn skip_literal_or_comment(bytes: &[u8], i: usize) -> Option<usize> {
    match (bytes[i], bytes.get(i + 1).copied()) {
        (b'\'' | b'"' | b'`', _) => Some(skip_quoted(bytes, i)),
        // Line comment: stop at (not past) the newline.
        (b'-', Some(b'-')) => Some(
            bytes[i..]
                .iter()
                .position(|&b| b == b'\n')
                .map_or(bytes.len(), |p| i + p),
        ),
        // Block comment: continue through the closing `*/`.
        (b'/', Some(b'*')) => Some(
            bytes[i + 2..]
                .windows(2)
                .position(|w| w == b"*/")
                .map_or(bytes.len(), |p| i + 2 + p + 2),
        ),
        _ => None,
    }
}

/// Returns the index just past the quoted run opening at `start`
/// (`bytes.len()` if unterminated). Two consecutive quotes are an escape.
fn skip_quoted(bytes: &[u8], start: usize) -> usize {
    let quote = bytes[start];
    let mut j = start + 1;
    while j < bytes.len() {
        if bytes[j] == quote {
            if bytes.get(j + 1) == Some(&quote) {
                j += 2;
                continue;
            }
            return j + 1;
        }
        j += 1;
    }
    bytes.len()
}

/// Maps a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column counts bytes from the start of the line, so a multi-byte UTF-8
/// character earlier on the same line advances it by more than one.
///
/// # Panics
/// Panics if `offset` is past the end of `source` or not on a `char` boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    (newlines + 1, offset - line_start + 1)
}