sqruff_lib/rules/convention/
cv04.rs

1use ahash::AHashMap;
2use itertools::Itertools;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder};
6
7use crate::core::config::Value;
8use crate::core::rules::context::RuleContext;
9use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
10use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
11use crate::utils::functional::context::FunctionalContext;
12
/// Configuration for rule CV04 (`convention.count_rows`).
///
/// Selects which argument the rule expects inside `COUNT(...)` when counting
/// rows: `1` if `prefer_count_1` is set, otherwise `0` if `prefer_count_0` is
/// set, otherwise `*`.
#[derive(Clone, Debug, Default)]
pub struct RuleCV04 {
    /// Prefer `COUNT(1)`; takes precedence over `prefer_count_0`.
    pub prefer_count_1: bool,
    /// Prefer `COUNT(0)`; only consulted when `prefer_count_1` is unset.
    pub prefer_count_0: bool,
}
18
19impl Rule for RuleCV04 {
20    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
21        Ok(RuleCV04 {
22            prefer_count_1: _config
23                .get("prefer_count_1")
24                .unwrap_or(&Value::Bool(false))
25                .as_bool()
26                .unwrap(),
27            prefer_count_0: _config
28                .get("prefer_count_0")
29                .unwrap_or(&Value::Bool(false))
30                .as_bool()
31                .unwrap(),
32        }
33        .erased())
34    }
35
36    fn name(&self) -> &'static str {
37        "convention.count_rows"
38    }
39
40    fn description(&self) -> &'static str {
41        "Use consistent syntax to express \"count number of rows\"."
42    }
43
44    fn long_description(&self) -> &'static str {
45        r#"
46**Anti-pattern**
47
48In this example, `count(1)` is used to count the number of rows in a table.
49
50```sql
51select
52    count(1)
53from table_a
54```
55
56**Best practice**
57
58Use count(*) unless specified otherwise by config prefer_count_1, or prefer_count_0 as preferred.
59
60```sql
61select
62    count(*)
63from table_a
64```
65"#
66    }
67
68    fn groups(&self) -> &'static [RuleGroups] {
69        &[RuleGroups::All, RuleGroups::Core, RuleGroups::Convention]
70    }
71
72    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
73        let Some(function_name) = context
74            .segment
75            .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
76        else {
77            return Vec::new();
78        };
79
80        if function_name.raw().eq_ignore_ascii_case("COUNT") {
81            let f_content = FunctionalContext::new(context)
82                .segment()
83                .children(Some(|it: &ErasedSegment| {
84                    it.is_type(SyntaxKind::FunctionContents)
85                }))
86                .children(Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)))
87                .children(Some(|it: &ErasedSegment| {
88                    !it.is_meta()
89                        && !matches!(
90                            it.get_type(),
91                            SyntaxKind::StartBracket
92                                | SyntaxKind::EndBracket
93                                | SyntaxKind::Whitespace
94                                | SyntaxKind::Newline
95                        )
96                }));
97
98            if f_content.len() != 1 {
99                return Vec::new();
100            }
101
102            let preferred = if self.prefer_count_1 {
103                "1"
104            } else if self.prefer_count_0 {
105                "0"
106            } else {
107                "*"
108            };
109
110            if f_content[0].is_type(SyntaxKind::Star)
111                && (self.prefer_count_0 || self.prefer_count_1)
112            {
113                let new_segment =
114                    SegmentBuilder::token(context.tables.next_id(), preferred, SyntaxKind::Literal)
115                        .finish();
116                return vec![LintResult::new(
117                    context.segment.clone().into(),
118                    vec![LintFix::replace(
119                        f_content[0].clone(),
120                        vec![new_segment],
121                        None,
122                    )],
123                    None,
124                    None,
125                )];
126            }
127
128            if f_content[0].is_type(SyntaxKind::Expression) {
129                let expression_content = f_content[0]
130                    .segments()
131                    .iter()
132                    .filter(|it| !it.is_meta())
133                    .collect_vec();
134
135                let raw = expression_content[0].raw();
136                if expression_content.len() == 1
137                    && matches!(
138                        expression_content[0].get_type(),
139                        SyntaxKind::NumericLiteral | SyntaxKind::Literal
140                    )
141                    && (raw == "0" || raw == "1")
142                    && raw != preferred
143                {
144                    let first_expression = expression_content[0].clone();
145                    let first_expression_raw = first_expression.raw();
146
147                    return vec![LintResult::new(
148                        context.segment.clone().into(),
149                        vec![LintFix::replace(
150                            first_expression.clone(),
151                            vec![
152                                first_expression.edit(
153                                    context.tables.next_id(),
154                                    first_expression
155                                        .raw()
156                                        .replace(first_expression_raw.as_str(), preferred)
157                                        .into(),
158                                    None,
159                                ),
160                            ],
161                            None,
162                        )],
163                        None,
164                        None,
165                    )];
166                }
167            }
168        }
169
170        Vec::new()
171    }
172
173    fn crawl_behaviour(&self) -> Crawler {
174        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::Function]) }).into()
175    }
176}