sqruff_lib/rules/convention/
cv04.rs1use ahash::AHashMap;
2use itertools::Itertools;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder};
6
7use crate::core::config::Value;
8use crate::core::rules::context::RuleContext;
9use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
10use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
11use crate::utils::functional::context::FunctionalContext;
12
/// Lint rule CV04 (`convention.count_rows`): enforce a single spelling for "count the rows".
///
/// With both flags `false` (the `Default`), `COUNT(*)` is the preferred form.
#[derive(Clone, Debug, Default)]
pub struct RuleCV04 {
    /// Prefer `COUNT(1)`. Takes precedence over `prefer_count_0` when both are set.
    pub prefer_count_1: bool,
    /// Prefer `COUNT(0)` (only honoured when `prefer_count_1` is `false`).
    pub prefer_count_0: bool,
}
18
19impl Rule for RuleCV04 {
20 fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
21 Ok(RuleCV04 {
22 prefer_count_1: _config
23 .get("prefer_count_1")
24 .unwrap_or(&Value::Bool(false))
25 .as_bool()
26 .unwrap(),
27 prefer_count_0: _config
28 .get("prefer_count_0")
29 .unwrap_or(&Value::Bool(false))
30 .as_bool()
31 .unwrap(),
32 }
33 .erased())
34 }
35
36 fn name(&self) -> &'static str {
37 "convention.count_rows"
38 }
39
40 fn description(&self) -> &'static str {
41 "Use consistent syntax to express \"count number of rows\"."
42 }
43
44 fn long_description(&self) -> &'static str {
45 r#"
46**Anti-pattern**
47
48In this example, `count(1)` is used to count the number of rows in a table.
49
50```sql
51select
52 count(1)
53from table_a
54```
55
56**Best practice**
57
58Use count(*) unless specified otherwise by config prefer_count_1, or prefer_count_0 as preferred.
59
60```sql
61select
62 count(*)
63from table_a
64```
65"#
66 }
67
68 fn groups(&self) -> &'static [RuleGroups] {
69 &[RuleGroups::All, RuleGroups::Core, RuleGroups::Convention]
70 }
71
72 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
73 let Some(function_name) = context
74 .segment
75 .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
76 else {
77 return Vec::new();
78 };
79
80 if function_name.raw().eq_ignore_ascii_case("COUNT") {
81 let f_content = FunctionalContext::new(context)
82 .segment()
83 .children(Some(|it: &ErasedSegment| {
84 it.is_type(SyntaxKind::FunctionContents)
85 }))
86 .children(Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)))
87 .children(Some(|it: &ErasedSegment| {
88 !it.is_meta()
89 && !matches!(
90 it.get_type(),
91 SyntaxKind::StartBracket
92 | SyntaxKind::EndBracket
93 | SyntaxKind::Whitespace
94 | SyntaxKind::Newline
95 )
96 }));
97
98 if f_content.len() != 1 {
99 return Vec::new();
100 }
101
102 let preferred = if self.prefer_count_1 {
103 "1"
104 } else if self.prefer_count_0 {
105 "0"
106 } else {
107 "*"
108 };
109
110 if f_content[0].is_type(SyntaxKind::Star)
111 && (self.prefer_count_0 || self.prefer_count_1)
112 {
113 let new_segment =
114 SegmentBuilder::token(context.tables.next_id(), preferred, SyntaxKind::Literal)
115 .finish();
116 return vec![LintResult::new(
117 context.segment.clone().into(),
118 vec![LintFix::replace(
119 f_content[0].clone(),
120 vec![new_segment],
121 None,
122 )],
123 None,
124 None,
125 )];
126 }
127
128 if f_content[0].is_type(SyntaxKind::Expression) {
129 let expression_content = f_content[0]
130 .segments()
131 .iter()
132 .filter(|it| !it.is_meta())
133 .collect_vec();
134
135 let raw = expression_content[0].raw();
136 if expression_content.len() == 1
137 && matches!(
138 expression_content[0].get_type(),
139 SyntaxKind::NumericLiteral | SyntaxKind::Literal
140 )
141 && (raw == "0" || raw == "1")
142 && raw != preferred
143 {
144 let first_expression = expression_content[0].clone();
145 let first_expression_raw = first_expression.raw();
146
147 return vec![LintResult::new(
148 context.segment.clone().into(),
149 vec![LintFix::replace(
150 first_expression.clone(),
151 vec![
152 first_expression.edit(
153 context.tables.next_id(),
154 first_expression
155 .raw()
156 .replace(first_expression_raw.as_str(), preferred)
157 .into(),
158 None,
159 ),
160 ],
161 None,
162 )],
163 None,
164 None,
165 )];
166 }
167 }
168 }
169
170 Vec::new()
171 }
172
173 fn crawl_behaviour(&self) -> Crawler {
174 SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::Function]) }).into()
175 }
176}