Skip to main content

sqrust_rules/lint/
update_set_duplicate.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{AssignmentTarget, ObjectName, Statement};
3use std::collections::HashMap;
4
/// Lint rule that flags `UPDATE` statements whose `SET` clause assigns the
/// same column more than once (column names are compared case-insensitively).
pub struct UpdateSetDuplicate;
6
7impl Rule for UpdateSetDuplicate {
8    fn name(&self) -> &'static str {
9        "Lint/UpdateSetDuplicate"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        // Skip files that failed to parse — AST may be incomplete.
14        if !ctx.parse_errors.is_empty() {
15            return Vec::new();
16        }
17
18        let mut diags = Vec::new();
19
20        for stmt in &ctx.statements {
21            if let Statement::Update { assignments, .. } = stmt {
22                // Count occurrences of each column name (lowercased).
23                let mut counts: HashMap<String, usize> = HashMap::new();
24                // Preserve insertion order for deterministic diagnostics.
25                let mut order: Vec<String> = Vec::new();
26
27                for assignment in assignments {
28                    if let AssignmentTarget::ColumnName(col_name) = &assignment.target {
29                        let name = extract_column_name(col_name);
30                        let lower = name.to_lowercase();
31                        let entry = counts.entry(lower.clone()).or_insert(0);
32                        *entry += 1;
33                        if *entry == 1 {
34                            order.push(lower);
35                        }
36                    }
37                }
38
39                // Emit one diagnostic per duplicated column name.
40                for col_lower in &order {
41                    if counts[col_lower] > 1 {
42                        // Find the position of the second occurrence of `col = ` in source.
43                        let (line, col) = find_second_occurrence(&ctx.source, col_lower);
44                        diags.push(Diagnostic {
45                            rule: self.name(),
46                            message: format!(
47                                "Column '{}' appears more than once in UPDATE SET clause",
48                                col_lower
49                            ),
50                            line,
51                            col,
52                        });
53                    }
54                }
55            }
56        }
57
58        diags
59    }
60}
61
62/// Extracts the column name from an ObjectName (last ident, preserving original case).
63fn extract_column_name(obj: &ObjectName) -> &str {
64    obj.0
65        .last()
66        .map(|id| id.value.as_str())
67        .unwrap_or("")
68}
69
70/// Finds the 1-indexed (line, col) of the second occurrence of
71/// `col_name\s*=` (case-insensitive) outside strings/comments in `source`.
72/// Falls back to (1, 1) if fewer than two occurrences are found.
73fn find_second_occurrence(source: &str, col_name: &str) -> (usize, usize) {
74    let source_lower = source.to_lowercase();
75    let name_lower = col_name.to_lowercase();
76    let name_len = name_lower.len();
77    let bytes = source_lower.as_bytes();
78    let len = bytes.len();
79
80    let mut found = 0usize;
81    let mut search_from = 0usize;
82
83    while search_from < len {
84        let Some(rel) = source_lower[search_from..].find(&name_lower) else {
85            break;
86        };
87        let abs = search_from + rel;
88
89        // Word boundary check.
90        let before_ok = abs == 0 || {
91            let b = bytes[abs - 1];
92            !b.is_ascii_alphanumeric() && b != b'_'
93        };
94        let after_name = abs + name_len;
95        let after_ok = after_name >= len || {
96            let b = bytes[after_name];
97            !b.is_ascii_alphanumeric() && b != b'_'
98        };
99
100        if before_ok && after_ok {
101            // Verify that `=` follows (optionally with whitespace), indicating an assignment.
102            let after_ws = skip_whitespace_in(bytes, after_name, len);
103            let is_assignment = after_ws < len && bytes[after_ws] == b'=';
104
105            if is_assignment {
106                found += 1;
107                if found == 2 {
108                    return offset_to_line_col(source, abs);
109                }
110            }
111        }
112
113        search_from = abs + 1;
114    }
115
116    (1, 1)
117}
118
/// Advances `pos` past any run of space, tab, line-feed or carriage-return
/// bytes in `bytes`, never moving beyond `len`, and returns the new position.
fn skip_whitespace_in(bytes: &[u8], mut pos: usize, len: usize) -> usize {
    loop {
        if pos >= len {
            return pos;
        }
        match bytes[pos] {
            b' ' | b'\t' | b'\n' | b'\r' => pos += 1,
            _ => return pos,
        }
    }
}
125
/// Translates a byte offset into `source` into a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` bytes before `offset`; the
/// column is the byte distance from the start of that line, plus one.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines_before = prefix.bytes().filter(|&b| b == b'\n').count();
    let col = match prefix.rfind('\n') {
        // Distance from the newline itself equals the 1-indexed column.
        Some(newline_at) => offset - newline_at,
        None => offset + 1,
    };
    (newlines_before + 1, col)
}