sqruff_lib/rules/convention/
cv04.rs1use ahash::AHashMap;
2use itertools::Itertools;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder};
6
7use crate::core::config::Value;
8use crate::core::rules::context::RuleContext;
9use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
10use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
11use crate::utils::functional::context::FunctionalContext;
12
/// Rule CV04 (`convention.count_rows`): enforce a single spelling of
/// "count the number of rows".
///
/// With both flags `false` (the `Default`), `COUNT(*)` is the preferred form.
/// Setting `prefer_count_1` selects `COUNT(1)` and wins over `prefer_count_0`,
/// which on its own selects `COUNT(0)`.
#[derive(Clone, Debug, Default)]
pub struct RuleCV04 {
    /// When `true`, `COUNT(1)` is the preferred row-count spelling.
    pub prefer_count_1: bool,
    /// When `true` (and `prefer_count_1` is not), `COUNT(0)` is preferred.
    pub prefer_count_0: bool,
}
18
19impl Rule for RuleCV04 {
20 fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
21 Ok(RuleCV04 {
22 prefer_count_1: _config
23 .get("prefer_count_1")
24 .unwrap_or(&Value::Bool(false))
25 .as_bool()
26 .unwrap(),
27 prefer_count_0: _config
28 .get("prefer_count_0")
29 .unwrap_or(&Value::Bool(false))
30 .as_bool()
31 .unwrap(),
32 }
33 .erased())
34 }
35
36 fn name(&self) -> &'static str {
37 "convention.count_rows"
38 }
39
40 fn description(&self) -> &'static str {
41 "Use consistent syntax to express \"count number of rows\"."
42 }
43
44 fn long_description(&self) -> &'static str {
45 r#"
46**Anti-pattern**
47
48In this example, `count(1)` is used to count the number of rows in a table.
49
50```sql
51select
52 count(1)
53from table_a
54```
55
56**Best practice**
57
58Use count(*) unless specified otherwise by config prefer_count_1, or prefer_count_0 as preferred.
59
60```sql
61select
62 count(*)
63from table_a
64```
65"#
66 }
67
68 fn groups(&self) -> &'static [RuleGroups] {
69 &[RuleGroups::All, RuleGroups::Core, RuleGroups::Convention]
70 }
71
72 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
73 let Some(function_name) = context
74 .segment
75 .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
76 else {
77 return Vec::new();
78 };
79
80 if function_name.raw().eq_ignore_ascii_case("COUNT") {
81 let f_content = FunctionalContext::new(context)
82 .segment()
83 .children_where(|it: &ErasedSegment| it.is_type(SyntaxKind::FunctionContents))
84 .children_where(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed))
85 .children_where(|it: &ErasedSegment| {
86 !it.is_meta()
87 && !matches!(
88 it.get_type(),
89 SyntaxKind::StartBracket
90 | SyntaxKind::EndBracket
91 | SyntaxKind::Whitespace
92 | SyntaxKind::Newline
93 )
94 });
95
96 if f_content.len() != 1 {
97 return Vec::new();
98 }
99
100 let preferred = if self.prefer_count_1 {
101 "1"
102 } else if self.prefer_count_0 {
103 "0"
104 } else {
105 "*"
106 };
107
108 if f_content[0].is_type(SyntaxKind::Star)
109 && (self.prefer_count_0 || self.prefer_count_1)
110 {
111 let new_segment =
112 SegmentBuilder::token(context.tables.next_id(), preferred, SyntaxKind::Literal)
113 .finish();
114 return vec![LintResult::new(
115 context.segment.clone().into(),
116 vec![LintFix::replace(
117 f_content[0].clone(),
118 vec![new_segment],
119 None,
120 )],
121 None,
122 None,
123 )];
124 }
125
126 if f_content[0].is_type(SyntaxKind::Expression) {
127 let expression_content = f_content[0]
128 .segments()
129 .iter()
130 .filter(|it| !it.is_meta())
131 .collect_vec();
132
133 let raw = expression_content[0].raw();
134 if expression_content.len() == 1
135 && matches!(
136 expression_content[0].get_type(),
137 SyntaxKind::NumericLiteral | SyntaxKind::Literal
138 )
139 && (raw == "0" || raw == "1")
140 && raw != preferred
141 {
142 let first_expression = expression_content[0].clone();
143 let first_expression_raw = first_expression.raw();
144
145 return vec![LintResult::new(
146 context.segment.clone().into(),
147 vec![LintFix::replace(
148 first_expression.clone(),
149 vec![
150 first_expression.edit(
151 context.tables.next_id(),
152 first_expression
153 .raw()
154 .replace(first_expression_raw.as_str(), preferred)
155 .into(),
156 None,
157 ),
158 ],
159 None,
160 )],
161 None,
162 None,
163 )];
164 }
165 }
166 }
167
168 Vec::new()
169 }
170
171 fn crawl_behaviour(&self) -> Crawler {
172 SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::Function]) }).into()
173 }
174}