Skip to main content

sqruff_lib/rules/references/
rf03.rs

1use std::cell::RefCell;
2
3use hashbrown::{HashMap, HashSet};
4use itertools::Itertools;
5use smol_str::SmolStr;
6use sqruff_lib_core::dialects::common::{AliasInfo, ColumnAliasInfo};
7use sqruff_lib_core::dialects::init::DialectKind;
8use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
9use sqruff_lib_core::helpers::capitalize;
10use sqruff_lib_core::lint_fix::LintFix;
11use sqruff_lib_core::parser::segments::object_reference::ObjectReferenceSegment;
12use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder, Tables};
13use sqruff_lib_core::utils::analysis::query::Query;
14
15use crate::core::config::Value;
16use crate::core::rules::context::RuleContext;
17use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
18use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
19
/// Lint rule RF03 (`references.consistent`): column references inside a
/// SELECT that reads from a single table should follow one qualification
/// style.
///
/// Both fields are filled in by `load_from_config`.
#[derive(Clone, Debug, Default)]
pub struct RuleRF03 {
    /// Rule-level reference policy. When `None`, `eval` falls back to the
    /// global `rules.single_table_references` setting.
    single_table_references: Option<String>,
    /// Value reported by `Rule::force_enable`.
    force_enable: bool,
}
25
26impl RuleRF03 {
27    fn visit_queries(
28        tables: &Tables,
29        single_table_references: &str,
30        is_struct_dialect: bool,
31        query: Query<'_>,
32        _visited: &mut HashSet<ErasedSegment>,
33    ) -> Vec<LintResult> {
34        #[allow(unused_assignments)]
35        let mut select_info = None;
36
37        let mut acc = Vec::new();
38        let selectables = &RefCell::borrow(&query.inner).selectables;
39
40        if !selectables.is_empty() {
41            select_info = selectables[0].select_info();
42
43            if let Some(select_info) = select_info
44                .clone()
45                .filter(|select_info| select_info.table_aliases.len() == 1)
46            {
47                let mut fixable = true;
48                let possible_ref_tables = iter_available_targets(query.clone());
49
50                if let Some(_parent) = &RefCell::borrow(&query.inner).parent {}
51
52                if possible_ref_tables.len() > 1 {
53                    fixable = false;
54                }
55
56                let results = check_references(
57                    tables,
58                    select_info.table_aliases,
59                    select_info.standalone_aliases,
60                    select_info.reference_buffer,
61                    select_info.col_aliases,
62                    single_table_references,
63                    is_struct_dialect,
64                    Some("qualified".into()),
65                    fixable,
66                );
67
68                acc.extend(results);
69            }
70        }
71
72        let children = query.children();
73        for child in children {
74            acc.extend(Self::visit_queries(
75                tables,
76                single_table_references,
77                is_struct_dialect,
78                child,
79                _visited,
80            ));
81        }
82
83        acc
84    }
85}
86
87fn iter_available_targets(query: Query<'_>) -> Vec<SmolStr> {
88    RefCell::borrow(&query.inner)
89        .selectables
90        .iter()
91        .flat_map(|selectable| {
92            selectable
93                .select_info()
94                .unwrap()
95                .table_aliases
96                .iter()
97                .map(|alias| alias.ref_str.clone())
98                .collect_vec()
99        })
100        .collect_vec()
101}
102
103#[allow(clippy::too_many_arguments)]
104fn check_references(
105    tables: &Tables,
106    table_aliases: Vec<AliasInfo>,
107    standalone_aliases: Vec<SmolStr>,
108    references: Vec<ObjectReferenceSegment>,
109    col_aliases: Vec<ColumnAliasInfo>,
110    single_table_references: &str,
111    is_struct_dialect: bool,
112    fix_inconsistent_to: Option<String>,
113    fixable: bool,
114) -> Vec<LintResult> {
115    let mut acc = Vec::new();
116
117    let col_alias_names = col_aliases
118        .clone()
119        .into_iter()
120        .map(|it| it.alias_identifier_name)
121        .collect_vec();
122
123    let table_ref_str = &table_aliases[0].ref_str;
124    let table_ref_str_source = table_aliases[0].segment.clone();
125    let mut seen_ref_types = HashSet::new();
126
127    for reference in references.clone() {
128        let mut this_ref_type = reference.qualification();
129        if this_ref_type == "qualified"
130            && is_struct_dialect
131            && &reference
132                .iter_raw_references()
133                .into_iter()
134                .next()
135                .unwrap()
136                .part
137                != table_ref_str
138        {
139            this_ref_type = "unqualified";
140        }
141
142        let lint_res = validate_one_reference(
143            tables,
144            single_table_references,
145            reference,
146            this_ref_type,
147            &standalone_aliases,
148            table_ref_str,
149            table_ref_str_source.clone(),
150            &col_alias_names,
151            &seen_ref_types,
152            fixable,
153        );
154
155        seen_ref_types.insert(this_ref_type);
156        let Some(lint_res) = lint_res else {
157            continue;
158        };
159
160        if let Some(fix_inconsistent_to) = fix_inconsistent_to
161            .as_ref()
162            .filter(|_| single_table_references == "consistent")
163        {
164            let results = check_references(
165                tables,
166                table_aliases.clone(),
167                standalone_aliases.clone(),
168                references.clone(),
169                col_aliases.clone(),
170                fix_inconsistent_to,
171                is_struct_dialect,
172                None,
173                fixable,
174            );
175
176            acc.extend(results);
177        }
178
179        acc.push(lint_res);
180    }
181
182    acc
183}
184
185#[allow(clippy::too_many_arguments)]
186fn validate_one_reference(
187    tables: &Tables,
188    single_table_references: &str,
189    ref_: ObjectReferenceSegment,
190    this_ref_type: &str,
191    standalone_aliases: &[SmolStr],
192    table_ref_str: &str,
193    _table_ref_str_source: Option<ErasedSegment>,
194    col_alias_names: &[SmolStr],
195    seen_ref_types: &HashSet<&str>,
196    fixable: bool,
197) -> Option<LintResult> {
198    if !ref_.is_qualified() && ref_.0.is_type(SyntaxKind::WildcardIdentifier) {
199        return None;
200    }
201
202    if ref_.0.is_templated() {
203        return None;
204    }
205
206    if standalone_aliases.contains(ref_.0.raw()) {
207        return None;
208    }
209
210    if table_ref_str.is_empty() {
211        return None;
212    }
213
214    if col_alias_names.contains(ref_.0.raw()) {
215        return None;
216    }
217
218    if single_table_references == "consistent" {
219        return if !seen_ref_types.is_empty() && !seen_ref_types.contains(this_ref_type) {
220            LintResult::new(
221                ref_.clone().0.into(),
222                Vec::new(),
223                format!(
224                    "{} reference '{}' found in single table select which is inconsistent with \
225                     previous references.",
226                    capitalize(this_ref_type),
227                    ref_.0.raw()
228                )
229                .into(),
230                None,
231            )
232            .into()
233        } else {
234            None
235        };
236    }
237
238    if single_table_references == this_ref_type {
239        return None;
240    }
241
242    if single_table_references == "unqualified" {
243        let fixes = if fixable {
244            ref_.0
245                .segments()
246                .iter()
247                .take(2)
248                .cloned()
249                .map(LintFix::delete)
250                .collect::<Vec<_>>()
251        } else {
252            Vec::new()
253        };
254
255        return LintResult::new(
256            ref_.0.clone().into(),
257            fixes,
258            format!(
259                "{} reference '{}' found in single table select.",
260                capitalize(this_ref_type),
261                ref_.0.raw()
262            )
263            .into(),
264            None,
265        )
266        .into();
267    }
268
269    let ref_ = ref_.0.clone();
270    let fixes = if fixable {
271        vec![LintFix::create_before(
272            if !ref_.segments().is_empty() {
273                ref_.segments()[0].clone()
274            } else {
275                ref_.clone()
276            },
277            vec![
278                SegmentBuilder::token(tables.next_id(), table_ref_str, SyntaxKind::NakedIdentifier)
279                    .finish(),
280                SegmentBuilder::symbol(tables.next_id(), "."),
281            ],
282        )]
283    } else {
284        Vec::new()
285    };
286
287    LintResult::new(
288        ref_.clone().into(),
289        fixes,
290        format!(
291            "{} reference '{}' found in single table select.",
292            capitalize(this_ref_type),
293            ref_.raw()
294        )
295        .into(),
296        None,
297    )
298    .into()
299}
300
301impl Rule for RuleRF03 {
302    fn load_from_config(&self, config: &HashMap<String, Value>) -> Result<ErasedRule, String> {
303        Ok(RuleRF03 {
304            single_table_references: config
305                .get("single_table_references")
306                .and_then(|it| it.as_string().map(ToString::to_string)),
307            force_enable: config["force_enable"].as_bool().unwrap(),
308        }
309        .erased())
310    }
311
312    fn name(&self) -> &'static str {
313        "references.consistent"
314    }
315
316    fn description(&self) -> &'static str {
317        "References should be consistent in statements with a single table."
318    }
319
320    fn long_description(&self) -> &'static str {
321        r#"
322**Anti-pattern**
323
324In this example, only the field b is referenced.
325
326```sql
327SELECT
328    a,
329    foo.b
330FROM foo
331```
332
333**Best practice**
334
335Add or remove references to all fields.
336
337```sql
338SELECT
339    a,
340    b
341FROM foo
342
343-- Also good
344
345SELECT
346    foo.a,
347    foo.b
348FROM foo
349```
350"#
351    }
352
353    fn groups(&self) -> &'static [RuleGroups] {
354        &[RuleGroups::All, RuleGroups::References]
355    }
356
357    fn force_enable(&self) -> bool {
358        self.force_enable
359    }
360
361    fn dialect_skip(&self) -> &'static [DialectKind] {
362        &[DialectKind::Bigquery, DialectKind::Redshift]
363    }
364
365    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
366        let single_table_references =
367            self.single_table_references.as_deref().unwrap_or_else(|| {
368                context.config.raw["rules"]["single_table_references"]
369                    .as_string()
370                    .unwrap()
371            });
372
373        let query: Query<'_> = Query::from_segment(&context.segment, context.dialect, None);
374        let mut visited: HashSet<ErasedSegment> = HashSet::new();
375        let is_struct_dialect = self.dialect_skip().contains(&context.dialect.name);
376
377        Self::visit_queries(
378            context.tables,
379            single_table_references,
380            is_struct_dialect,
381            query,
382            &mut visited,
383        )
384    }
385
386    fn is_fix_compatible(&self) -> bool {
387        true
388    }
389
390    fn crawl_behaviour(&self) -> Crawler {
391        SegmentSeekerCrawler::new(
392            const {
393                SyntaxSet::new(&[
394                    SyntaxKind::SelectStatement,
395                    SyntaxKind::SetExpression,
396                    SyntaxKind::WithCompoundStatement,
397                ])
398            },
399        )
400        .disallow_recurse()
401        .into()
402    }
403}