sqruff_lib/rules/structure/
st02.rs

1use ahash::{AHashMap, AHashSet};
2use itertools::{Itertools, chain};
3use smol_str::StrExt;
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder};
7use sqruff_lib_core::utils::functional::segments::Segments;
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14
/// Lint rule `structure.simple_case` (ST02): reports `CASE` expressions that
/// can be written as a `COALESCE` call or as the bare column.
///
/// The rule carries no state, so it is a unit struct.
#[derive(Clone, Debug, Default)]
pub struct RuleST02;
17
18impl Rule for RuleST02 {
19    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
20        Ok(RuleST02.erased())
21    }
22
23    fn name(&self) -> &'static str {
24        "structure.simple_case"
25    }
26
27    fn description(&self) -> &'static str {
28        "Unnecessary 'CASE' statement."
29    }
30
31    fn long_description(&self) -> &'static str {
32        r#"
33**Anti-pattern**
34
35CASE statement returns booleans.
36
37```sql
38select
39    case
40        when fab > 0 then true
41        else false
42    end as is_fab
43from fancy_table
44
45-- This rule can also simplify CASE statements
46-- that aim to fill NULL values.
47
48select
49    case
50        when fab is null then 0
51        else fab
52    end as fab_clean
53from fancy_table
54
55-- This also covers where the case statement
56-- replaces NULL values with NULL values.
57
58select
59    case
60        when fab is null then null
61        else fab
62    end as fab_clean
63from fancy_table
64```
65
66**Best practice**
67
68Reduce to WHEN condition within COALESCE function.
69
70```sql
71select
72    coalesce(fab > 0, false) as is_fab
73from fancy_table
74
75-- To fill NULL values.
76
77select
78    coalesce(fab, 0) as fab_clean
79from fancy_table
80
81-- NULL filling NULL.
82
83select fab as fab_clean
84from fancy_table
85```
86"#
87    }
88
89    fn groups(&self) -> &'static [RuleGroups] {
90        &[RuleGroups::All, RuleGroups::Structure]
91    }
92
93    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
94        if context.segment.segments()[0]
95            .raw()
96            .eq_ignore_ascii_case("CASE")
97        {
98            let children = FunctionalContext::new(context).segment().children_all();
99
100            let when_clauses =
101                children.filter(|it: &ErasedSegment| it.is_type(SyntaxKind::WhenClause));
102            let else_clauses =
103                children.filter(|it: &ErasedSegment| it.is_type(SyntaxKind::ElseClause));
104
105            if when_clauses.len() > 1 {
106                return Vec::new();
107            }
108
109            let condition_expression =
110                when_clauses.children_where(|it| it.is_type(SyntaxKind::Expression))[0].clone();
111            let then_expression =
112                when_clauses.children_where(|it| it.is_type(SyntaxKind::Expression))[1].clone();
113
114            if !else_clauses.is_empty()
115                && let Some(else_expression) = else_clauses
116                    .children_where(|it| it.is_type(SyntaxKind::Expression))
117                    .first()
118            {
119                let upper_bools = ["TRUE", "FALSE"];
120
121                let then_expression_upper = then_expression.raw().to_uppercase_smolstr();
122                let else_expression_upper = else_expression.raw().to_uppercase_smolstr();
123
124                if upper_bools.contains(&then_expression_upper.as_str())
125                    && upper_bools.contains(&else_expression_upper.as_str())
126                    && then_expression_upper != else_expression_upper
127                {
128                    let coalesce_arg_1 = condition_expression.clone();
129                    let coalesce_arg_2 = SegmentBuilder::keyword(context.tables.next_id(), "false");
130                    let preceding_not = then_expression_upper == "FALSE";
131
132                    let fixes = Self::coalesce_fix_list(
133                        context,
134                        coalesce_arg_1,
135                        coalesce_arg_2,
136                        preceding_not,
137                    );
138
139                    return vec![LintResult::new(
140                        condition_expression.into(),
141                        fixes,
142                        "Unnecessary CASE statement. Use COALESCE function instead."
143                            .to_owned()
144                            .into(),
145                        None,
146                    )];
147                }
148            }
149
150            let condition_expression_segments_raw: AHashSet<_> = AHashSet::from_iter(
151                condition_expression
152                    .segments()
153                    .iter()
154                    .map(|segment| segment.raw().to_uppercase_smolstr()),
155            );
156
157            if condition_expression_segments_raw.contains("IS")
158                && condition_expression_segments_raw.contains("NULL")
159                && condition_expression_segments_raw
160                    .intersection(&AHashSet::from_iter(["AND".into(), "OR".into()]))
161                    .next()
162                    .is_none()
163            {
164                let is_not_prefix = condition_expression_segments_raw.contains("NOT");
165
166                let tmp = Segments::new(condition_expression.clone(), None)
167                    .children_where(|it| it.is_type(SyntaxKind::ColumnReference));
168
169                let Some(column_reference_segment) = tmp.first() else {
170                    return Vec::new();
171                };
172
173                let array_accessor_segment = Segments::new(condition_expression.clone(), None)
174                    .children_where(|it: &ErasedSegment| it.is_type(SyntaxKind::ArrayAccessor))
175                    .first()
176                    .cloned();
177
178                let column_reference_segment_raw_upper = match array_accessor_segment {
179                    Some(array_accessor_segment) => {
180                        column_reference_segment.raw().to_lowercase()
181                            + &array_accessor_segment.raw().to_uppercase()
182                    }
183                    None => column_reference_segment.raw().to_uppercase(),
184                };
185
186                if !else_clauses.is_empty() {
187                    let else_expression = else_clauses
188                        .children_where(|it| it.is_type(SyntaxKind::Expression))[0]
189                        .clone();
190
191                    let (coalesce_arg_1, coalesce_arg_2) = if !is_not_prefix
192                        && column_reference_segment_raw_upper
193                            == else_expression.raw().to_uppercase_smolstr()
194                    {
195                        (else_expression, then_expression)
196                    } else if is_not_prefix
197                        && column_reference_segment_raw_upper
198                            == then_expression.raw().to_uppercase_smolstr()
199                    {
200                        (then_expression, else_expression)
201                    } else {
202                        return Vec::new();
203                    };
204
205                    if coalesce_arg_2.raw().eq_ignore_ascii_case("NULL") {
206                        let fixes =
207                            Self::column_only_fix_list(context, column_reference_segment.clone());
208                        return vec![LintResult::new(
209                            condition_expression.into(),
210                            fixes,
211                            Some(String::new()),
212                            None,
213                        )];
214                    }
215
216                    let fixes =
217                        Self::coalesce_fix_list(context, coalesce_arg_1, coalesce_arg_2, false);
218
219                    return vec![LintResult::new(
220                        condition_expression.into(),
221                        fixes,
222                        "Unnecessary CASE statement. Use COALESCE function instead."
223                            .to_owned()
224                            .into(),
225                        None,
226                    )];
227                } else if column_reference_segment
228                    .raw()
229                    .eq_ignore_ascii_case(then_expression.raw())
230                {
231                    let fixes =
232                        Self::column_only_fix_list(context, column_reference_segment.clone());
233
234                    return vec![LintResult::new(
235                        condition_expression.into(),
236                        fixes,
237                        format!(
238                            "Unnecessary CASE statement. Just use column '{}'.",
239                            column_reference_segment.raw()
240                        )
241                        .into(),
242                        None,
243                    )];
244                }
245            }
246
247            Vec::new()
248        } else {
249            Vec::new()
250        }
251    }
252
253    fn crawl_behaviour(&self) -> Crawler {
254        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::CaseExpression]) }).into()
255    }
256}
257
258impl RuleST02 {
259    fn coalesce_fix_list(
260        context: &RuleContext,
261        coalesce_arg_1: ErasedSegment,
262        coalesce_arg_2: ErasedSegment,
263        preceding_not: bool,
264    ) -> Vec<LintFix> {
265        let mut edits = vec![
266            SegmentBuilder::token(
267                context.tables.next_id(),
268                "coalesce",
269                SyntaxKind::FunctionNameIdentifier,
270            )
271            .finish(),
272            SegmentBuilder::symbol(context.tables.next_id(), "("),
273            coalesce_arg_1,
274            SegmentBuilder::symbol(context.tables.next_id(), ","),
275            SegmentBuilder::whitespace(context.tables.next_id(), " "),
276            coalesce_arg_2,
277            SegmentBuilder::symbol(context.tables.next_id(), ")"),
278        ];
279
280        if preceding_not {
281            edits = chain(
282                [
283                    SegmentBuilder::keyword(context.tables.next_id(), "not"),
284                    SegmentBuilder::whitespace(context.tables.next_id(), " "),
285                ],
286                edits,
287            )
288            .collect_vec();
289        }
290
291        vec![LintFix::replace(context.segment.clone(), edits, None)]
292    }
293
294    fn column_only_fix_list(
295        context: &RuleContext,
296        column_reference_segment: ErasedSegment,
297    ) -> Vec<LintFix> {
298        vec![LintFix::replace(
299            context.segment.clone(),
300            vec![column_reference_segment],
301            None,
302        )]
303    }
304}