Skip to main content

sqrust_rules/convention/
nvl_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that reports Oracle-specific `NVL()` and `NVL2()` calls.
///
/// The rule carries no state, so the usual value-type traits are derived.
#[derive(Debug, Clone, Copy, Default)]
pub struct NvlFunction;
4
/// Diagnostic text emitted for an `NVL(` call.
const MESSAGE_NVL: &str = "NVL() is Oracle-specific; use COALESCE() for standard SQL";

/// Diagnostic text emitted for an `NVL2(` call.
const MESSAGE_NVL2: &str =
    "NVL2() is Oracle-specific; use CASE WHEN col IS NOT NULL THEN ... ELSE ... END instead";
10
11impl Rule for NvlFunction {
12    fn name(&self) -> &'static str {
13        "Convention/NvlFunction"
14    }
15
16    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
17        find_violations(&ctx.source, self.name())
18    }
19}
20
/// Collects the byte offsets of `source` that lie inside single-quoted string
/// literals or `--` line comments, so that keyword scanning can ignore them.
///
/// * The quote bytes that open and close a literal are not included.
/// * An escaped quote (`''`) inside a literal is included in full (both bytes).
/// * A line comment is included from its first `-` up to, but not including,
///   the terminating newline.
/// * An unterminated literal extends to the end of `source`.
///
/// Block comments (`/* ... */`) and double-quoted identifiers are not
/// recognised; their contents are treated as ordinary code.
fn build_skip_set(source: &str) -> std::collections::HashSet<usize> {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut skip = std::collections::HashSet::new();
    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' => {
                // Step over the opening quote, then consume the literal body.
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        skip.insert(i);
                        i += 1;
                    } else if bytes.get(i + 1) == Some(&b'\'') {
                        // `''` is an escaped quote: both bytes are literal content.
                        skip.insert(i);
                        skip.insert(i + 1);
                        i += 2;
                    } else {
                        // Closing quote ends the literal.
                        i += 1;
                        break;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: everything up to the newline is skipped.
                while i < len && bytes[i] != b'\n' {
                    skip.insert(i);
                    i += 1;
                }
            }
            _ => i += 1,
        }
    }
    skip
}
54
/// Translates a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line. `offset` must
/// be a char boundary no greater than `source.len()`; otherwise the slice
/// below panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines_before = prefix.matches('\n').count();
    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline_pos) => newline_pos + 1,
        None => 0,
    };
    (newlines_before + 1, offset - line_start + 1)
}
62
/// Reports whether `ch` is an ASCII letter, an ASCII digit, or an underscore.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
67
68fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
69    let bytes = source.as_bytes();
70    let len = bytes.len();
71
72    if len == 0 {
73        return Vec::new();
74    }
75
76    let skip = build_skip_set(source);
77    let mut diags = Vec::new();
78
79    // We search for both NVL2( (4 chars + 1 paren) and NVL( (3 chars + 1 paren).
80    // Try NVL2 first at each position, then fall back to NVL.
81    let nvl2 = b"NVL2";
82    let nvl = b"NVL";
83
84    let mut i = 0;
85    while i < len {
86        if skip.contains(&i) {
87            i += 1;
88            continue;
89        }
90
91        // Word boundary before
92        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
93        if !before_ok {
94            i += 1;
95            continue;
96        }
97
98        // Try to match NVL2 first (4 chars)
99        if i + nvl2.len() <= len
100            && bytes[i..i + nvl2.len()].eq_ignore_ascii_case(nvl2)
101        {
102            // Ensure all keyword bytes are code
103            let all_code = (0..nvl2.len()).all(|k| !skip.contains(&(i + k)));
104            if all_code {
105                let kw_end = i + nvl2.len();
106                // Must be followed by '(' and NOT followed by a word char (avoid NVL2X)
107                let after_ok = kw_end < len
108                    && bytes[kw_end] == b'('
109                    && (kw_end + 1 >= len || !is_word_char(bytes[kw_end]));
110                // The '(' itself already ensures it's not a longer word
111                if kw_end < len && bytes[kw_end] == b'(' {
112                    let (line, col) = line_col(source, i);
113                    diags.push(Diagnostic {
114                        rule: rule_name,
115                        message: MESSAGE_NVL2.to_string(),
116                        line,
117                        col,
118                    });
119                    i = kw_end + 1;
120                    let _ = after_ok;
121                    continue;
122                }
123            }
124        }
125
126        // Try to match NVL (3 chars), but ensure it's not NVL2 (word boundary after NVL)
127        if i + nvl.len() <= len
128            && bytes[i..i + nvl.len()].eq_ignore_ascii_case(nvl)
129        {
130            // Ensure all keyword bytes are code
131            let all_code = (0..nvl.len()).all(|k| !skip.contains(&(i + k)));
132            if all_code {
133                let kw_end = i + nvl.len();
134                // Must be immediately followed by '(' (not NVL2, NVLx, etc.)
135                if kw_end < len && bytes[kw_end] == b'(' {
136                    let (line, col) = line_col(source, i);
137                    diags.push(Diagnostic {
138                        rule: rule_name,
139                        message: MESSAGE_NVL.to_string(),
140                        line,
141                        col,
142                    });
143                    i = kw_end + 1;
144                    continue;
145                }
146            }
147        }
148
149        i += 1;
150    }
151
152    diags
153}