Skip to main content

sqrust_rules/lint/
cross_database_reference.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Delete, FromTable, Query, Select, SetExpr, Statement, TableFactor, TableWithJoins,
4};
5
6pub struct CrossDatabaseReference;
7
8impl Rule for CrossDatabaseReference {
9    fn name(&self) -> &'static str {
10        "Lint/CrossDatabaseReference"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        // Skip files that failed to parse — AST may be incomplete.
15        if !ctx.parse_errors.is_empty() {
16            return Vec::new();
17        }
18
19        let mut diags = Vec::new();
20
21        for stmt in &ctx.statements {
22            match stmt {
23                Statement::Query(q) => {
24                    check_query(q, ctx, &mut diags);
25                }
26                Statement::Insert(insert) => {
27                    // Check the target table of INSERT INTO
28                    if insert.table_name.0.len() >= 3 {
29                        let name_str = insert
30                            .table_name
31                            .0
32                            .iter()
33                            .map(|i| i.value.as_str())
34                            .collect::<Vec<_>>()
35                            .join(".");
36                        let (line, col) =
37                            find_name_position(ctx.source.as_str(), &name_str);
38                        diags.push(make_diagnostic(ctx, name_str, line, col));
39                    }
40                }
41                Statement::Update {
42                    table: TableWithJoins { relation, joins },
43                    ..
44                } => {
45                    check_table_factor(relation, ctx, &mut diags);
46                    for join in joins {
47                        check_table_factor(&join.relation, ctx, &mut diags);
48                    }
49                }
50                Statement::Delete(Delete { from, .. }) => {
51                    let tables = match from {
52                        FromTable::WithFromKeyword(v) | FromTable::WithoutKeyword(v) => v,
53                    };
54                    for twj in tables {
55                        check_table_with_joins(twj, ctx, &mut diags);
56                    }
57                }
58                _ => {}
59            }
60        }
61
62        diags
63    }
64}
65
66fn check_query(q: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
67    if let Some(with) = &q.with {
68        for cte in &with.cte_tables {
69            check_query(&cte.query, ctx, diags);
70        }
71    }
72    check_set_expr(&q.body, ctx, diags);
73}
74
75fn check_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
76    match expr {
77        SetExpr::Select(sel) => check_select(sel, ctx, diags),
78        SetExpr::Query(q) => check_query(q, ctx, diags),
79        SetExpr::SetOperation { left, right, .. } => {
80            check_set_expr(left, ctx, diags);
81            check_set_expr(right, ctx, diags);
82        }
83        _ => {}
84    }
85}
86
87fn check_select(sel: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
88    for twj in &sel.from {
89        check_table_with_joins(twj, ctx, diags);
90    }
91}
92
93fn check_table_with_joins(
94    twj: &TableWithJoins,
95    ctx: &FileContext,
96    diags: &mut Vec<Diagnostic>,
97) {
98    check_table_factor(&twj.relation, ctx, diags);
99    for join in &twj.joins {
100        check_table_factor(&join.relation, ctx, diags);
101    }
102}
103
104fn check_table_factor(tf: &TableFactor, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
105    if let TableFactor::Table { name, .. } = tf {
106        if name.0.len() >= 3 {
107            let name_str = name
108                .0
109                .iter()
110                .map(|i| i.value.as_str())
111                .collect::<Vec<_>>()
112                .join(".");
113            let (line, col) = find_name_position(ctx.source.as_str(), &name_str);
114            diags.push(make_diagnostic(ctx, name_str, line, col));
115        }
116    }
117}
118
119fn make_diagnostic(
120    _ctx: &FileContext,
121    name_str: String,
122    line: usize,
123    col: usize,
124) -> Diagnostic {
125    Diagnostic {
126        rule: "Lint/CrossDatabaseReference",
127        message: format!(
128            "Cross-database table reference '{}' — in dbt, use ref() or source() macros \
129             instead of hardcoded cross-database paths",
130            name_str
131        ),
132        line,
133        col,
134    }
135}
136
137/// Finds the 1-indexed (line, col) of the first occurrence of `name` in `source`.
138/// Falls back to (1, 1) if not found.
139fn find_name_position(source: &str, name: &str) -> (usize, usize) {
140    let source_upper = source.to_uppercase();
141    let name_upper = name.to_uppercase();
142    if let Some(pos) = source_upper.find(&name_upper) {
143        return offset_to_line_col(source, pos);
144    }
145    (1, 1)
146}
147
/// Converts a byte offset in `source` to a 1-indexed (line, col) pair, where
/// `col` is measured in bytes from the start of the line.
///
/// The offset is clamped to the end of `source` and then moved back to the
/// nearest char boundary. The clamped value is used for both the slice and the
/// column arithmetic, so an out-of-range offset yields the end position and a
/// mid-character offset cannot panic.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];
    let line = before.matches('\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    let col = end - line_start + 1;
    (line, col)
}
154}