sqruff_lib/core/
rules.rs

1pub mod context;
2pub mod crawlers;
3pub mod noqa;
4pub mod reference;
5
6use std::fmt::{self, Debug};
7use std::ops::Deref;
8
9use std::sync::Arc;
10
11use ahash::{AHashMap, AHashSet};
12use itertools::chain;
13use sqruff_lib_core::dialects::Dialect;
14use sqruff_lib_core::dialects::init::DialectKind;
15use sqruff_lib_core::errors::{ErrorStructRule, SQLLintError};
16use sqruff_lib_core::helpers::{Config, IndexMap};
17use sqruff_lib_core::lint_fix::LintFix;
18use sqruff_lib_core::parser::segments::{ErasedSegment, Tables};
19use sqruff_lib_core::templaters::TemplatedFile;
20use strum_macros::AsRefStr;
21
22use crate::core::config::{FluffConfig, Value};
23use crate::core::rules::context::RuleContext;
24use crate::core::rules::crawlers::{BaseCrawler as _, Crawler};
25
26pub struct LintResult {
27    pub anchor: Option<ErasedSegment>,
28    pub fixes: Vec<LintFix>,
29    description: Option<String>,
30    source: String,
31}
32
33#[derive(Debug, Clone, PartialEq, Copy, Hash, Eq, AsRefStr)]
34#[strum(serialize_all = "lowercase")]
35pub enum RuleGroups {
36    All,
37    Core,
38    Aliasing,
39    Ambiguous,
40    Capitalisation,
41    Convention,
42    Layout,
43    References,
44    Structure,
45}
46
47impl LintResult {
48    pub fn new(
49        anchor: Option<ErasedSegment>,
50        fixes: Vec<LintFix>,
51        description: Option<String>,
52        source: Option<String>,
53    ) -> Self {
54        // let fixes = fixes.into_iter().filter(|f| !f.is_trivial()).collect();
55
56        LintResult {
57            anchor,
58            fixes,
59            description,
60            source: source.unwrap_or_default(),
61        }
62    }
63
64    pub fn to_linting_error(self, rule: &ErasedRule) -> Option<SQLLintError> {
65        let anchor = self.anchor.clone()?;
66
67        let description = self
68            .description
69            .as_deref()
70            .unwrap_or_else(|| rule.description());
71
72        let is_fixable = rule.is_fix_compatible();
73
74        SQLLintError::new(description, anchor, is_fixable)
75            .config(|this| {
76                this.rule = Some(ErrorStructRule {
77                    name: rule.name(),
78                    code: rule.code(),
79                })
80            })
81            .into()
82    }
83}
84
85impl Debug for LintResult {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        match &self.anchor {
88            None => write!(f, "LintResult(<empty>)"),
89            Some(anchor) => {
90                let fix_coda = if !self.fixes.is_empty() {
91                    format!("+{}F", self.fixes.len())
92                } else {
93                    "".to_string()
94                };
95
96                match &self.description {
97                    Some(desc) => {
98                        if !self.source.is_empty() {
99                            write!(
100                                f,
101                                "LintResult({} [{}]: {:?}{})",
102                                desc, self.source, anchor, fix_coda
103                            )
104                        } else {
105                            write!(f, "LintResult({desc}: {anchor:?}{fix_coda})")
106                        }
107                    }
108                    None => write!(f, "LintResult({anchor:?}{fix_coda})"),
109                }
110            }
111        }
112    }
113}
114
/// Phase a rule declares itself to belong to via `Rule::lint_phase`.
///
/// Rules report [`LintPhase::Main`] unless they override `lint_phase`.
#[derive(Debug, Clone, PartialEq)]
pub enum LintPhase {
    /// The default phase.
    Main,
    /// The alternative phase a rule can opt into.
    Post,
}
120
121pub trait Rule: Debug + 'static + Send + Sync {
122    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String>;
123
124    fn lint_phase(&self) -> LintPhase {
125        LintPhase::Main
126    }
127
128    fn name(&self) -> &'static str;
129
130    fn config_ref(&self) -> &'static str {
131        self.name()
132    }
133
134    fn description(&self) -> &'static str;
135
136    fn long_description(&self) -> &'static str;
137
138    /// All the groups this rule belongs to, including 'all' because that is a
139    /// given. There should be no duplicates and 'all' should be the first
140    /// element.
141    fn groups(&self) -> &'static [RuleGroups];
142
143    fn force_enable(&self) -> bool {
144        false
145    }
146
147    /// Returns the set of dialects for which a particular rule should be
148    /// skipped.
149    fn dialect_skip(&self) -> &'static [DialectKind] {
150        &[]
151    }
152
153    fn code(&self) -> &'static str {
154        let name = std::any::type_name::<Self>();
155        name.split("::")
156            .last()
157            .unwrap()
158            .strip_prefix("Rule")
159            .unwrap_or(name)
160    }
161
162    fn eval(&self, rule_cx: &RuleContext) -> Vec<LintResult>;
163
164    fn is_fix_compatible(&self) -> bool {
165        false
166    }
167
168    fn crawl_behaviour(&self) -> Crawler;
169}
170
/// Marker error returned by [`crawl`] when a rule panicked during evaluation.
#[derive(Debug)]
pub struct Exception;
172
173pub fn crawl(
174    rule: &ErasedRule,
175    tables: &Tables,
176    dialect: &Dialect,
177    templated_file: &TemplatedFile,
178    tree: ErasedSegment,
179    config: &FluffConfig,
180    on_violation: &mut impl FnMut(LintResult),
181) -> Result<(), Exception> {
182    let mut root_context = RuleContext::new(tables, dialect, config, tree.clone());
183    let mut has_exception = false;
184
185    // TODO Will to return a note that rules were skipped
186    if rule.dialect_skip().contains(&dialect.name) && !rule.force_enable() {
187        return Ok(());
188    }
189
190    rule.crawl_behaviour()
191        .crawl(&mut root_context, &mut |context| {
192            let resp =
193                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| rule.eval(context)));
194
195            let Ok(results) = resp else {
196                has_exception = true;
197                return;
198            };
199
200            for result in results {
201                if !result
202                    .fixes
203                    .iter()
204                    .any(|it| it.has_template_conflicts(templated_file))
205                {
206                    on_violation(result);
207                }
208            }
209        });
210
211    if has_exception {
212        Err(Exception)
213    } else {
214        Ok(())
215    }
216}
217
218#[derive(Debug, Clone)]
219pub struct ErasedRule {
220    erased: Arc<dyn Rule>,
221}
222
223impl PartialEq for ErasedRule {
224    fn eq(&self, _other: &Self) -> bool {
225        unimplemented!()
226    }
227}
228
229impl Deref for ErasedRule {
230    type Target = dyn Rule;
231
232    fn deref(&self) -> &Self::Target {
233        self.erased.as_ref()
234    }
235}
236
/// Conversion of a concrete value into its type-erased form.
pub trait Erased {
    /// The type-erased representation produced by [`Erased::erased`].
    type Erased;

    /// Consumes `self` and returns the type-erased value.
    fn erased(self) -> Self::Erased;
}
242
243impl<T: Rule> Erased for T {
244    type Erased = ErasedRule;
245
246    fn erased(self) -> Self::Erased {
247        ErasedRule {
248            erased: Arc::new(self),
249        }
250    }
251}
252
253pub struct RuleManifest {
254    pub code: &'static str,
255    pub name: &'static str,
256    pub description: &'static str,
257    pub groups: &'static [RuleGroups],
258    pub rule_class: ErasedRule,
259}
260
261#[derive(Clone)]
262pub struct RulePack {
263    pub(crate) rules: Vec<ErasedRule>,
264    _reference_map: AHashMap<&'static str, AHashSet<&'static str>>,
265}
266
267impl RulePack {
268    pub fn rules(&self) -> Vec<ErasedRule> {
269        self.rules.clone()
270    }
271}
272
273pub struct RuleSet {
274    pub(crate) register: IndexMap<&'static str, RuleManifest>,
275}
276
277impl RuleSet {
278    fn rule_reference_map(&self) -> AHashMap<&'static str, AHashSet<&'static str>> {
279        let valid_codes: AHashSet<_> = self.register.keys().copied().collect();
280
281        let reference_map: AHashMap<_, AHashSet<_>> = valid_codes
282            .iter()
283            .map(|&code| (code, AHashSet::from([code])))
284            .collect();
285
286        let name_map = {
287            let mut name_map = AHashMap::new();
288            for manifest in self.register.values() {
289                name_map
290                    .entry(manifest.name)
291                    .or_insert_with(AHashSet::new)
292                    .insert(manifest.code);
293            }
294            name_map
295        };
296
297        let name_collisions: AHashSet<_> = {
298            let name_keys: AHashSet<_> = name_map.keys().copied().collect();
299            name_keys.intersection(&valid_codes).copied().collect()
300        };
301
302        if !name_collisions.is_empty() {
303            log::warn!(
304                "The following defined rule names were found which collide with codes. Those \
305                 names will not be available for selection: {name_collisions:?}",
306            );
307        }
308
309        let reference_map: AHashMap<_, _> = chain(name_map, reference_map).collect();
310
311        let mut group_map: AHashMap<_, AHashSet<&'static str>> = AHashMap::new();
312        for manifest in self.register.values() {
313            for group in manifest.groups {
314                let group = group.as_ref();
315                if let Some(codes) = reference_map.get(group) {
316                    log::warn!(
317                        "Rule {} defines group '{}' which is already defined as a name or code of \
318                         {:?}. This group will not be available for use as a result of this \
319                         collision.",
320                        manifest.code,
321                        group,
322                        codes
323                    );
324                } else {
325                    group_map
326                        .entry(group)
327                        .or_insert_with(AHashSet::new)
328                        .insert(manifest.code);
329                }
330            }
331        }
332
333        chain(group_map, reference_map).collect()
334    }
335
336    fn expand_rule_refs(
337        &self,
338        glob_list: Vec<String>,
339        reference_map: &AHashMap<&'static str, AHashSet<&'static str>>,
340    ) -> AHashSet<&'static str> {
341        let mut expanded_rule_set = AHashSet::new();
342
343        for r in glob_list {
344            if reference_map.contains_key(r.as_str()) {
345                expanded_rule_set.extend(reference_map[r.as_str()].clone());
346            } else {
347                panic!("Rule {r} not found in rule reference map");
348            }
349        }
350
351        expanded_rule_set
352    }
353
354    pub(crate) fn get_rulepack(&self, config: &FluffConfig) -> RulePack {
355        let reference_map = self.rule_reference_map();
356        let rules = config.get_section("rules");
357        let keylist = self.register.keys();
358        let mut instantiated_rules = Vec::with_capacity(keylist.len());
359
360        let allowlist: Vec<String> = match config.get("rule_allowlist", "core").as_array() {
361            Some(array) => array
362                .iter()
363                .map(|it| it.as_string().unwrap().to_owned())
364                .collect(),
365            None => self.register.keys().map(|it| it.to_string()).collect(),
366        };
367
368        let denylist: Vec<String> = match config.get("rule_denylist", "core").as_array() {
369            Some(array) => array
370                .iter()
371                .map(|it| it.as_string().unwrap().to_owned())
372                .collect(),
373            None => Vec::new(),
374        };
375
376        let expanded_allowlist = self.expand_rule_refs(allowlist, &reference_map);
377        let expanded_denylist = self.expand_rule_refs(denylist, &reference_map);
378
379        let keylist: Vec<_> = keylist
380            .into_iter()
381            .filter(|&&r| expanded_allowlist.contains(r) && !expanded_denylist.contains(r))
382            .collect();
383
384        for code in keylist {
385            let rule = self.register[code].rule_class.clone();
386            let rule_config_ref = rule.config_ref();
387
388            let tmp = AHashMap::new();
389
390            let specific_rule_config = rules
391                .get(rule_config_ref)
392                .and_then(|section| section.as_map())
393                .unwrap_or(&tmp);
394
395            // TODO fail the rulepack if any need unwrapping
396            instantiated_rules.push(rule.load_from_config(specific_rule_config).unwrap());
397        }
398
399        RulePack {
400            rules: instantiated_rules,
401            _reference_map: reference_map,
402        }
403    }
404}