// sqrust_rules/structure/nested_subquery.rs

use sqrust_core::{Diagnostic, FileContext, Rule};

use crate::capitalisation::{is_word_char, SkipMap};

/// Lint rule that flags SQL sources containing too many subqueries.
///
/// A subquery is recognised textually as an opening parenthesis followed by
/// the `SELECT` keyword. Note that the rule counts such openers; it does not
/// track how deeply they are actually nested inside one another.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NestedSubquery {
    /// Maximum number of `(SELECT` openers tolerated in one source.
    /// A source containing more than this many is flagged.
    pub max_depth: usize,
}

11impl Default for NestedSubquery {
12    fn default() -> Self {
13        NestedSubquery { max_depth: 2 }
14    }
15}
16
17impl Rule for NestedSubquery {
18    fn name(&self) -> &'static str {
19        "Structure/NestedSubquery"
20    }
21
22    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
23        let source = &ctx.source;
24        let bytes = source.as_bytes();
25        let len = bytes.len();
26        let skip_map = SkipMap::build(source);
27
28        // Count occurrences of `(` followed by optional whitespace followed by
29        // the keyword SELECT (word-boundary on both sides).
30        //
31        // Each such pattern represents one level of subquery nesting.
32        // We record the byte offset of the SELECT keyword at which we first
33        // exceed max_depth so we can provide an accurate line/col.
34        let mut depth: usize = 0;
35        let mut first_excess_offset: Option<usize> = None;
36
37        let mut i = 0;
38        while i < len {
39            // Skip bytes that are inside strings, comments, or quoted identifiers.
40            if !skip_map.is_code(i) {
41                i += 1;
42                continue;
43            }
44
45            let b = bytes[i];
46
47            // Look for `(` in code.
48            if b == b'(' {
49                // Scan forward past optional whitespace to find SELECT.
50                let mut j = i + 1;
51                while j < len && (bytes[j] == b' ' || bytes[j] == b'\t' || bytes[j] == b'\n' || bytes[j] == b'\r') {
52                    j += 1;
53                }
54
55                // j now points at the first non-whitespace byte after `(`.
56                // Check whether it starts the keyword SELECT (case-insensitive,
57                // word-boundary after).
58                if j + 6 <= len {
59                    let candidate = &bytes[j..j + 6];
60                    let is_select = b"SELECT"
61                        .iter()
62                        .zip(candidate.iter())
63                        .all(|(a, b)| a.eq_ignore_ascii_case(b));
64
65                    // Word boundary after SELECT: next byte must not be alphanumeric/_
66                    let boundary_after = j + 6 >= len || {
67                        let nb = bytes[j + 6];
68                        !is_word_char(nb)
69                    };
70
71                    // All bytes of SELECT must be real code (not inside a skip region).
72                    let all_code = (j..j + 6).all(|k| skip_map.is_code(k));
73
74                    if is_select && boundary_after && all_code {
75                        depth += 1;
76                        if depth > self.max_depth && first_excess_offset.is_none() {
77                            first_excess_offset = Some(j);
78                        }
79                    }
80                }
81
82                i += 1;
83                continue;
84            }
85
86            i += 1;
87        }
88
89        if depth > self.max_depth {
90            let offset = first_excess_offset.unwrap_or(0);
91            let (line, col) = line_col(source, offset);
92            vec![Diagnostic {
93                rule: self.name(),
94                message: format!(
95                    "Subquery nesting depth {depth} exceeds maximum of {max}; consider using CTEs",
96                    depth = depth,
97                    max = self.max_depth,
98                ),
99                line,
100                col,
101            }]
102        } else {
103            Vec::new()
104        }
105    }
106}
107
/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// Lines are delimited by `\n`. The column is measured in bytes from the
/// start of the line, so it equals the character column only for lines whose
/// preceding text is ASCII.
///
/// # Panics
///
/// Panics if `offset` is greater than `source.len()` or does not fall on a
/// UTF-8 character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.bytes().filter(|&b| b == b'\n').count() + 1;
    // Byte index where the current line begins (just past the last newline).
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}