sqrust_rules/lint/
update_set_duplicate.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{AssignmentTarget, ObjectName, Statement};
3use std::collections::HashMap;
4
/// Lint rule `Lint/UpdateSetDuplicate`: reports columns that are assigned more than
/// once in the `SET` clause of an `UPDATE` statement.
#[derive(Debug, Default, Clone, Copy)]
pub struct UpdateSetDuplicate;
6
7impl Rule for UpdateSetDuplicate {
8 fn name(&self) -> &'static str {
9 "Lint/UpdateSetDuplicate"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
15 return Vec::new();
16 }
17
18 let mut diags = Vec::new();
19
20 for stmt in &ctx.statements {
21 if let Statement::Update { assignments, .. } = stmt {
22 let mut counts: HashMap<String, usize> = HashMap::new();
24 let mut order: Vec<String> = Vec::new();
26
27 for assignment in assignments {
28 if let AssignmentTarget::ColumnName(col_name) = &assignment.target {
29 let name = extract_column_name(col_name);
30 let lower = name.to_lowercase();
31 let entry = counts.entry(lower.clone()).or_insert(0);
32 *entry += 1;
33 if *entry == 1 {
34 order.push(lower);
35 }
36 }
37 }
38
39 for col_lower in &order {
41 if counts[col_lower] > 1 {
42 let (line, col) = find_second_occurrence(&ctx.source, col_lower);
44 diags.push(Diagnostic {
45 rule: self.name(),
46 message: format!(
47 "Column '{}' appears more than once in UPDATE SET clause",
48 col_lower
49 ),
50 line,
51 col,
52 });
53 }
54 }
55 }
56 }
57
58 diags
59 }
60}
61
62fn extract_column_name(obj: &ObjectName) -> &str {
64 obj.0
65 .last()
66 .map(|id| id.value.as_str())
67 .unwrap_or("")
68}
69
70fn find_second_occurrence(source: &str, col_name: &str) -> (usize, usize) {
74 let source_lower = source.to_lowercase();
75 let name_lower = col_name.to_lowercase();
76 let name_len = name_lower.len();
77 let bytes = source_lower.as_bytes();
78 let len = bytes.len();
79
80 let mut found = 0usize;
81 let mut search_from = 0usize;
82
83 while search_from < len {
84 let Some(rel) = source_lower[search_from..].find(&name_lower) else {
85 break;
86 };
87 let abs = search_from + rel;
88
89 let before_ok = abs == 0 || {
91 let b = bytes[abs - 1];
92 !b.is_ascii_alphanumeric() && b != b'_'
93 };
94 let after_name = abs + name_len;
95 let after_ok = after_name >= len || {
96 let b = bytes[after_name];
97 !b.is_ascii_alphanumeric() && b != b'_'
98 };
99
100 if before_ok && after_ok {
101 let after_ws = skip_whitespace_in(bytes, after_name, len);
103 let is_assignment = after_ws < len && bytes[after_ws] == b'=';
104
105 if is_assignment {
106 found += 1;
107 if found == 2 {
108 return offset_to_line_col(source, abs);
109 }
110 }
111 }
112
113 search_from = abs + 1;
114 }
115
116 (1, 1)
117}
118
/// Returns the index of the first byte in `pos..len` that is not a space, tab, line
/// feed, or carriage return.
///
/// If only those whitespace bytes remain, returns `len`. When `pos >= len`, returns
/// `pos` unchanged.
fn skip_whitespace_in(bytes: &[u8], pos: usize, len: usize) -> usize {
    (pos..len)
        .find(|&i| !matches!(bytes[i], b' ' | b'\t' | b'\n' | b'\r'))
        .unwrap_or(pos.max(len))
}
125
/// Converts a byte offset into `source` into a 1-based `(line, column)` pair.
///
/// Lines are separated by `\n`. The column counts characters, not bytes, so an `é`
/// earlier on the same line shifts it by one rather than by its UTF-8 length.
///
/// An offset past the end of `source` is clamped to its length. An offset that falls
/// inside a multi-byte character is moved back to the start of that character. As a
/// result, the slice below cannot panic.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut offset = offset.min(source.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(offset) {
        offset -= 1;
    }

    let before = &source[..offset];
    let line = before.matches('\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    let col = before[line_start..].chars().count() + 1;
    (line, col)
}