Skip to main content

sqrust_rules/lint/
duplicate_join.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Query, Select, SetExpr, Statement, TableFactor, TableWithJoins, With};
3use std::collections::{HashMap, HashSet};
4
5pub struct DuplicateJoin;
6
7impl Rule for DuplicateJoin {
8    fn name(&self) -> &'static str {
9        "Lint/DuplicateJoin"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            check_stmt(stmt, &ctx.source, "Lint/DuplicateJoin", &mut diags);
19        }
20        diags
21    }
22}
23
24fn check_stmt(stmt: &Statement, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
25    if let Statement::Query(q) = stmt {
26        check_query(q, src, rule, diags);
27    }
28}
29
30fn check_query(q: &Query, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
31    if let Some(With { cte_tables, .. }) = &q.with {
32        for cte in cte_tables {
33            check_query(&cte.query, src, rule, diags);
34        }
35    }
36    check_set_expr(&q.body, src, rule, diags);
37}
38
39fn check_set_expr(body: &SetExpr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
40    match body {
41        SetExpr::Select(s) => check_select(s, src, rule, diags),
42        SetExpr::SetOperation { left, right, .. } => {
43            check_set_expr(left, src, rule, diags);
44            check_set_expr(right, src, rule, diags);
45        }
46        SetExpr::Query(q) => check_query(q, src, rule, diags),
47        _ => {}
48    }
49}
50
51fn check_select(sel: &Select, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
52    for twj in &sel.from {
53        check_table_with_joins(twj, src, rule, diags);
54    }
55}
56
57fn check_table_with_joins(
58    twj: &TableWithJoins,
59    src: &str,
60    rule: &'static str,
61    diags: &mut Vec<Diagnostic>,
62) {
63    // Collect all table names (lowercased full name) with first occurrence offset
64    let mut seen: HashMap<String, usize> = HashMap::new();
65    let mut already_flagged: HashSet<String> = HashSet::new();
66    let mut last_off: usize = 0;
67
68    // Main table
69    if let Some((name, off)) = table_factor_name(&twj.relation, src, last_off) {
70        last_off = off + 1;
71        seen.insert(name, off);
72    }
73
74    // Recurse into subqueries in main table
75    check_factor_subqueries(&twj.relation, src, rule, diags);
76
77    // JOINs
78    for join in &twj.joins {
79        check_factor_subqueries(&join.relation, src, rule, diags);
80        if let Some((name, off)) = table_factor_name(&join.relation, src, last_off) {
81            last_off = off + 1;
82            if seen.contains_key(&name) && !already_flagged.contains(&name) {
83                let (line, col) = offset_to_line_col(src, off);
84                diags.push(Diagnostic {
85                    rule,
86                    message: format!(
87                        "Table '{}' is joined more than once in the same FROM clause",
88                        name
89                    ),
90                    line,
91                    col,
92                });
93                already_flagged.insert(name.clone());
94            } else if !seen.contains_key(&name) {
95                seen.insert(name, off);
96            }
97        }
98    }
99}
100
101fn table_factor_name(tf: &TableFactor, src: &str, start: usize) -> Option<(String, usize)> {
102    match tf {
103        TableFactor::Table { name, .. } => {
104            let full_name = name
105                .0
106                .iter()
107                .map(|i| i.value.to_lowercase())
108                .collect::<Vec<_>>()
109                .join(".");
110            let last = name.0.last()?.value.clone();
111            let off = find_word_in_source(src, &last, start)?;
112            Some((full_name, off))
113        }
114        _ => None,
115    }
116}
117
118fn check_factor_subqueries(
119    tf: &TableFactor,
120    src: &str,
121    rule: &'static str,
122    diags: &mut Vec<Diagnostic>,
123) {
124    if let TableFactor::Derived { subquery, .. } = tf {
125        check_query(subquery, src, rule, diags);
126    }
127}
128
/// Finds the first whole-word, ASCII-case-insensitive occurrence of `word`
/// in `src` at or after byte offset `start`.
///
/// A match counts as a whole word when the bytes immediately before and
/// after it (if any) are not word characters according to [`is_wc`].
/// Returns the byte offset of the match, or `None` when `word` is empty or
/// no match exists.
fn find_word_in_source(src: &str, word: &str, start: usize) -> Option<usize> {
    let haystack = src.as_bytes();
    let needle = word.as_bytes();
    if needle.is_empty() {
        return None;
    }
    // Last offset at which the needle still fits; `None` if it never fits.
    let last_start = haystack.len().checked_sub(needle.len())?;
    (start..=last_start).find(|&pos| {
        let end = pos + needle.len();
        haystack[pos..end].eq_ignore_ascii_case(needle)
            && (pos == 0 || !is_wc(haystack[pos - 1]))
            && (end == haystack.len() || !is_wc(haystack[end]))
    })
}

/// Whether `b` is an ASCII letter, ASCII digit, or underscore.
fn is_wc(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
153
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. An `offset`
/// past the end of `source` is clamped to the end, and an `offset` that falls
/// inside a multi-byte UTF-8 character is moved back to the start of that
/// character, so the function never panics and both line and column always
/// describe the same (clamped) position.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Slicing a `str` off a char boundary panics; offset 0 is always a
    // boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    // Byte index of the first character on the current line.
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    (line, end - line_start + 1)
}