sqruff_lib/rules/references/
rf01.rs

1use std::cell::RefCell;
2
3use ahash::AHashMap;
4use itertools::Itertools;
5use smol_str::SmolStr;
6use sqruff_lib_core::dialects::Dialect;
7use sqruff_lib_core::dialects::common::AliasInfo;
8use sqruff_lib_core::dialects::init::DialectKind;
9use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
10use sqruff_lib_core::parser::segments::object_reference::{
11    ObjectReferenceLevel, ObjectReferencePart, ObjectReferenceSegment,
12};
13use sqruff_lib_core::utils::analysis::query::{Query, QueryInner, Selectable};
14
15use crate::core::config::Value;
16use crate::core::rules::context::RuleContext;
17use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
18use crate::core::rules::reference::object_ref_matches_table;
19use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
20
/// Alias information accumulated for a single query while the rule runs.
#[derive(Debug, Default, Clone)]
struct RF01QueryData {
    /// Table aliases taken from the `select_info` of each selectable in the query.
    aliases: Vec<AliasInfo>,
    /// Standalone alias names taken from the same `select_info`.
    standalone_aliases: Vec<SmolStr>,
}
26
/// Map key identifying a query: a raw pointer to its `RefCell<QueryInner>`.
/// The value produced by `Query::id()` is used as this key throughout the rule.
type QueryKey<'a> = *const RefCell<QueryInner<'a>>;
/// Per-query alias data gathered during one `eval` call, keyed by [`QueryKey`].
type RF01State<'a> = AHashMap<QueryKey<'a>, RF01QueryData>;
29
/// Rule `references.from`: reports references to a table/view that is not found in the
/// `FROM` clause of the query or of an ancestor query.
#[derive(Debug, Clone, Default)]
pub struct RuleRF01 {
    /// Value read from the `force_enable` config key and returned by `Rule::force_enable`.
    force_enable: bool,
}
34
35impl RuleRF01 {
36    #[allow(clippy::only_used_in_recursion)]
37    fn resolve_reference<'a>(
38        &self,
39        r: &ObjectReferenceSegment,
40        tbl_refs: Vec<(ObjectReferencePart, Vec<SmolStr>)>,
41        dml_target_table: &[SmolStr],
42        query: Query<'a>,
43        payloads: &RF01State<'a>,
44    ) -> Option<LintResult> {
45        let possible_references: Vec<_> = tbl_refs
46            .clone()
47            .into_iter()
48            .map(|tbl_ref| tbl_ref.1)
49            .collect();
50
51        let mut targets = vec![];
52
53        if let Some(payload) = payloads.get(&query.id()) {
54            for alias in &payload.aliases {
55                if alias.aliased {
56                    targets.push(vec![alias.ref_str.clone()]);
57                }
58
59                if let Some(object_reference) = &alias.object_reference {
60                    let references = object_reference
61                        .reference()
62                        .iter_raw_references()
63                        .into_iter()
64                        .map(|it| it.part.into())
65                        .collect_vec();
66
67                    targets.push(references);
68                }
69            }
70
71            for standalone_alias in &payload.standalone_aliases {
72                targets.push(vec![standalone_alias.clone()]);
73            }
74        }
75
76        if !object_ref_matches_table(&possible_references, &targets) {
77            if let Some(parent) = RefCell::borrow(&query.inner).parent.clone() {
78                return self.resolve_reference(
79                    r,
80                    tbl_refs.clone(),
81                    dml_target_table,
82                    parent,
83                    payloads,
84                );
85            } else if dml_target_table.is_empty()
86                || !object_ref_matches_table(&possible_references, &[dml_target_table.to_vec()])
87            {
88                return LintResult::new(
89                    tbl_refs[0].0.segments[0].clone().into(),
90                    Vec::new(),
91                    format!(
92                        "Reference '{}' refers to table/view not found in the FROM clause or \
93                         found in ancestor statement.",
94                        r.0.raw()
95                    )
96                    .into(),
97                    None,
98                )
99                .into();
100            }
101        }
102        None
103    }
104
105    fn get_table_refs(
106        &self,
107        r: &ObjectReferenceSegment,
108        dialect: &Dialect,
109    ) -> Vec<(ObjectReferencePart, Vec<SmolStr>)> {
110        let mut tbl_refs = Vec::new();
111
112        for values in r.extract_possible_multipart_references(&[
113            ObjectReferenceLevel::Schema,
114            ObjectReferenceLevel::Table,
115        ]) {
116            tbl_refs.push((
117                values[1].clone(),
118                vec![values[0].part.clone().into(), values[1].part.clone().into()],
119            ));
120        }
121
122        if tbl_refs.is_empty() || dialect.name == DialectKind::Bigquery {
123            tbl_refs.extend(
124                r.extract_possible_references(ObjectReferenceLevel::Table, dialect.name)
125                    .into_iter()
126                    .map(|it| (it.clone(), vec![it.part.into()])),
127            );
128        }
129
130        tbl_refs
131    }
132
133    fn analyze_table_references<'a>(
134        &self,
135        query: Query<'a>,
136        dml_target_table: &[SmolStr],
137        payloads: &mut RF01State<'a>,
138        violations: &mut Vec<LintResult>,
139    ) {
140        payloads.entry(query.id()).or_default();
141        let selectables = std::mem::take(&mut RefCell::borrow_mut(&query.inner).selectables);
142
143        for selectable in &selectables {
144            if let Some(select_info) = selectable.select_info() {
145                let table_aliases = select_info.table_aliases;
146                let standalone_aliases = select_info.standalone_aliases;
147                let reference_buffer = select_info.reference_buffer;
148
149                {
150                    let payload = payloads.entry(query.id()).or_default();
151                    payload.aliases.extend(table_aliases);
152                    payload.standalone_aliases.extend(standalone_aliases);
153                }
154
155                for r in reference_buffer {
156                    if !self.should_ignore_reference(&r, selectable) {
157                        let violation = self.resolve_reference(
158                            &r,
159                            self.get_table_refs(&r, RefCell::borrow(&query.inner).dialect),
160                            dml_target_table,
161                            query.clone(),
162                            payloads,
163                        );
164                        violations.extend(violation);
165                    }
166                }
167            }
168        }
169
170        RefCell::borrow_mut(&query.inner).selectables = selectables;
171
172        for child in query.children() {
173            self.analyze_table_references(child, dml_target_table, payloads, violations);
174        }
175    }
176
177    fn should_ignore_reference(
178        &self,
179        reference: &ObjectReferenceSegment,
180        selectable: &Selectable,
181    ) -> bool {
182        let ref_path = selectable.selectable.path_to(&reference.0);
183
184        if !ref_path.is_empty() {
185            ref_path
186                .iter()
187                .any(|ps| ps.segment.is_type(SyntaxKind::IntoTableClause))
188        } else {
189            false
190        }
191    }
192}
193
194impl Rule for RuleRF01 {
195    fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
196        Ok(RuleRF01 {
197            force_enable: config["force_enable"].as_bool().unwrap(),
198        }
199        .erased())
200    }
201
    /// Returns the rule's name, `references.from`.
    fn name(&self) -> &'static str {
        "references.from"
    }
205
    /// Returns the one-line summary of what the rule enforces.
    fn description(&self) -> &'static str {
        "References cannot reference objects not present in 'FROM' clause."
    }
209
    /// Returns the Markdown explanation of the rule: an anti-pattern SQL example followed
    /// by the recommended form.
    fn long_description(&self) -> &'static str {
        r#"
**Anti-pattern**

In this example, the reference `vee` has not been declared.

```sql
SELECT
    vee.a
FROM foo
```

**Best practice**

Remove the reference.

```sql
SELECT
    a
FROM foo
```
"#
    }
233
    /// Returns the groups this rule is listed under: `All`, `Core` and `References`.
    fn groups(&self) -> &'static [RuleGroups] {
        &[RuleGroups::All, RuleGroups::Core, RuleGroups::References]
    }
237
    /// Returns the `force_enable` flag stored on the rule (set by `load_from_config`).
    fn force_enable(&self) -> bool {
        self.force_enable
    }
241
    /// Returns the dialects listed for skipping by this rule: Redshift, BigQuery and
    /// SparkSQL.
    // NOTE(review): presumably the rule is not run for these dialects unless
    // `force_enable` is set — confirm against the rule runner.
    fn dialect_skip(&self) -> &'static [DialectKind] {
        // TODO Add others when finished, whole list["databricks", "hive", "soql"]
        &[
            DialectKind::Redshift,
            DialectKind::Bigquery,
            DialectKind::Sparksql,
        ]
    }
250
251    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
252        let query = Query::from_segment(&context.segment, context.dialect, None);
253        let mut payloads = RF01State::default();
254        let mut violations = Vec::new();
255        let tmp;
256
257        let dml_target_table = if !context.segment.is_type(SyntaxKind::SelectStatement) {
258            let refs = context.segment.recursive_crawl(
259                const { &SyntaxSet::new(&[SyntaxKind::TableReference]) },
260                true,
261                &SyntaxSet::EMPTY,
262                true,
263            );
264            if let Some(reference) = refs.first() {
265                let reference = reference.reference();
266
267                tmp = reference
268                    .iter_raw_references()
269                    .into_iter()
270                    .map(|it| it.part.into())
271                    .collect_vec();
272                &tmp
273            } else {
274                [].as_slice()
275            }
276        } else {
277            &[]
278        };
279
280        self.analyze_table_references(query, dml_target_table, &mut payloads, &mut violations);
281
282        violations
283    }
284
285    fn crawl_behaviour(&self) -> Crawler {
286        SegmentSeekerCrawler::new(
287            const {
288                SyntaxSet::new(&[
289                    SyntaxKind::DeleteStatement,
290                    SyntaxKind::MergeStatement,
291                    SyntaxKind::SelectStatement,
292                    SyntaxKind::UpdateStatement,
293                ])
294            },
295        )
296        .disallow_recurse()
297        .into()
298    }
299}