sqruff_lib/rules/structure/
st02.rs

1use ahash::{AHashMap, AHashSet};
2use itertools::{Itertools, chain};
3use smol_str::StrExt;
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder};
7use sqruff_lib_core::utils::functional::segments::Segments;
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14
/// Rule ST02 (`structure.simple_case`): reports CASE expressions that can be
/// replaced by the bare column or by a `COALESCE` call.
#[derive(Clone, Debug, Default)]
pub struct RuleST02;
17
18impl Rule for RuleST02 {
19    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
20        Ok(RuleST02.erased())
21    }
22
23    fn name(&self) -> &'static str {
24        "structure.simple_case"
25    }
26
27    fn description(&self) -> &'static str {
28        "Unnecessary 'CASE' statement."
29    }
30
31    fn long_description(&self) -> &'static str {
32        r#"
33**Anti-pattern**
34
35CASE statement returns booleans.
36
37```sql
38select
39    case
40        when fab > 0 then true
41        else false
42    end as is_fab
43from fancy_table
44
45-- This rule can also simplify CASE statements
46-- that aim to fill NULL values.
47
48select
49    case
50        when fab is null then 0
51        else fab
52    end as fab_clean
53from fancy_table
54
55-- This also covers where the case statement
56-- replaces NULL values with NULL values.
57
58select
59    case
60        when fab is null then null
61        else fab
62    end as fab_clean
63from fancy_table
64```
65
66**Best practice**
67
68Reduce to WHEN condition within COALESCE function.
69
70```sql
71select
72    coalesce(fab > 0, false) as is_fab
73from fancy_table
74
75-- To fill NULL values.
76
77select
78    coalesce(fab, 0) as fab_clean
79from fancy_table
80
81-- NULL filling NULL.
82
83select fab as fab_clean
84from fancy_table
85```
86"#
87    }
88
89    fn groups(&self) -> &'static [RuleGroups] {
90        &[RuleGroups::All, RuleGroups::Structure]
91    }
92
93    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
94        if context.segment.segments()[0]
95            .raw()
96            .eq_ignore_ascii_case("CASE")
97        {
98            let children = FunctionalContext::new(context).segment().children(None);
99
100            let when_clauses = children.select(
101                Some(|it: &ErasedSegment| it.is_type(SyntaxKind::WhenClause)),
102                None,
103                None,
104                None,
105            );
106            let else_clauses = children.select(
107                Some(|it: &ErasedSegment| it.is_type(SyntaxKind::ElseClause)),
108                None,
109                None,
110                None,
111            );
112
113            if when_clauses.len() > 1 {
114                return Vec::new();
115            }
116
117            let condition_expression =
118                when_clauses.children(Some(|it| it.is_type(SyntaxKind::Expression)))[0].clone();
119            let then_expression =
120                when_clauses.children(Some(|it| it.is_type(SyntaxKind::Expression)))[1].clone();
121
122            if !else_clauses.is_empty()
123                && let Some(else_expression) = else_clauses
124                    .children(Some(|it| it.is_type(SyntaxKind::Expression)))
125                    .first()
126            {
127                let upper_bools = ["TRUE", "FALSE"];
128
129                let then_expression_upper = then_expression.raw().to_uppercase_smolstr();
130                let else_expression_upper = else_expression.raw().to_uppercase_smolstr();
131
132                if upper_bools.contains(&then_expression_upper.as_str())
133                    && upper_bools.contains(&else_expression_upper.as_str())
134                    && then_expression_upper != else_expression_upper
135                {
136                    let coalesce_arg_1 = condition_expression.clone();
137                    let coalesce_arg_2 = SegmentBuilder::keyword(context.tables.next_id(), "false");
138                    let preceding_not = then_expression_upper == "FALSE";
139
140                    let fixes = Self::coalesce_fix_list(
141                        context,
142                        coalesce_arg_1,
143                        coalesce_arg_2,
144                        preceding_not,
145                    );
146
147                    return vec![LintResult::new(
148                        condition_expression.into(),
149                        fixes,
150                        "Unnecessary CASE statement. Use COALESCE function instead."
151                            .to_owned()
152                            .into(),
153                        None,
154                    )];
155                }
156            }
157
158            let condition_expression_segments_raw: AHashSet<_> = AHashSet::from_iter(
159                condition_expression
160                    .segments()
161                    .iter()
162                    .map(|segment| segment.raw().to_uppercase_smolstr()),
163            );
164
165            if condition_expression_segments_raw.contains("IS")
166                && condition_expression_segments_raw.contains("NULL")
167                && condition_expression_segments_raw
168                    .intersection(&AHashSet::from_iter(["AND".into(), "OR".into()]))
169                    .next()
170                    .is_none()
171            {
172                let is_not_prefix = condition_expression_segments_raw.contains("NOT");
173
174                let tmp = Segments::new(condition_expression.clone(), None)
175                    .children(Some(|it| it.is_type(SyntaxKind::ColumnReference)));
176
177                let Some(column_reference_segment) = tmp.first() else {
178                    return Vec::new();
179                };
180
181                let array_accessor_segment = Segments::new(condition_expression.clone(), None)
182                    .children(Some(|it: &ErasedSegment| {
183                        it.is_type(SyntaxKind::ArrayAccessor)
184                    }))
185                    .first()
186                    .cloned();
187
188                let column_reference_segment_raw_upper = match array_accessor_segment {
189                    Some(array_accessor_segment) => {
190                        column_reference_segment.raw().to_lowercase()
191                            + &array_accessor_segment.raw().to_uppercase()
192                    }
193                    None => column_reference_segment.raw().to_uppercase(),
194                };
195
196                if !else_clauses.is_empty() {
197                    let else_expression = else_clauses
198                        .children(Some(|it| it.is_type(SyntaxKind::Expression)))[0]
199                        .clone();
200
201                    let (coalesce_arg_1, coalesce_arg_2) = if !is_not_prefix
202                        && column_reference_segment_raw_upper
203                            == else_expression.raw().to_uppercase_smolstr()
204                    {
205                        (else_expression, then_expression)
206                    } else if is_not_prefix
207                        && column_reference_segment_raw_upper
208                            == then_expression.raw().to_uppercase_smolstr()
209                    {
210                        (then_expression, else_expression)
211                    } else {
212                        return Vec::new();
213                    };
214
215                    if coalesce_arg_2.raw().eq_ignore_ascii_case("NULL") {
216                        let fixes =
217                            Self::column_only_fix_list(context, column_reference_segment.clone());
218                        return vec![LintResult::new(
219                            condition_expression.into(),
220                            fixes,
221                            Some(String::new()),
222                            None,
223                        )];
224                    }
225
226                    let fixes =
227                        Self::coalesce_fix_list(context, coalesce_arg_1, coalesce_arg_2, false);
228
229                    return vec![LintResult::new(
230                        condition_expression.into(),
231                        fixes,
232                        "Unnecessary CASE statement. Use COALESCE function instead."
233                            .to_owned()
234                            .into(),
235                        None,
236                    )];
237                } else if column_reference_segment
238                    .raw()
239                    .eq_ignore_ascii_case(then_expression.raw())
240                {
241                    let fixes =
242                        Self::column_only_fix_list(context, column_reference_segment.clone());
243
244                    return vec![LintResult::new(
245                        condition_expression.into(),
246                        fixes,
247                        format!(
248                            "Unnecessary CASE statement. Just use column '{}'.",
249                            column_reference_segment.raw()
250                        )
251                        .into(),
252                        None,
253                    )];
254                }
255            }
256
257            Vec::new()
258        } else {
259            Vec::new()
260        }
261    }
262
263    fn crawl_behaviour(&self) -> Crawler {
264        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::CaseExpression]) }).into()
265    }
266}
267
268impl RuleST02 {
269    fn coalesce_fix_list(
270        context: &RuleContext,
271        coalesce_arg_1: ErasedSegment,
272        coalesce_arg_2: ErasedSegment,
273        preceding_not: bool,
274    ) -> Vec<LintFix> {
275        let mut edits = vec![
276            SegmentBuilder::token(
277                context.tables.next_id(),
278                "coalesce",
279                SyntaxKind::FunctionNameIdentifier,
280            )
281            .finish(),
282            SegmentBuilder::symbol(context.tables.next_id(), "("),
283            coalesce_arg_1,
284            SegmentBuilder::symbol(context.tables.next_id(), ","),
285            SegmentBuilder::whitespace(context.tables.next_id(), " "),
286            coalesce_arg_2,
287            SegmentBuilder::symbol(context.tables.next_id(), ")"),
288        ];
289
290        if preceding_not {
291            edits = chain(
292                [
293                    SegmentBuilder::keyword(context.tables.next_id(), "not"),
294                    SegmentBuilder::whitespace(context.tables.next_id(), " "),
295                ],
296                edits,
297            )
298            .collect_vec();
299        }
300
301        vec![LintFix::replace(context.segment.clone(), edits, None)]
302    }
303
304    fn column_only_fix_list(
305        context: &RuleContext,
306        column_reference_segment: ErasedSegment,
307    ) -> Vec<LintFix> {
308        vec![LintFix::replace(
309            context.segment.clone(),
310            vec![column_reference_segment],
311            None,
312        )]
313    }
314}