sqruff_lib/rules/structure/
st04.rs

1use ahash::AHashMap;
2use itertools::Itertools;
3use smol_str::ToSmolStr;
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder, Tables};
7use sqruff_lib_core::utils::functional::segments::Segments;
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased as _, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14use crate::utils::reflow::reindent::{IndentUnit, construct_single_indent};
15
/// ST04 (`structure.nested_case`): flags a `CASE` whose `ELSE` branch is nothing
/// but another `CASE`, which could be flattened into the outer expression.
#[derive(Debug, Default, Clone)]
pub struct RuleST04;
18
19impl Rule for RuleST04 {
20    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
21        Ok(RuleST04.erased())
22    }
23
24    fn name(&self) -> &'static str {
25        "structure.nested_case"
26    }
27
28    fn description(&self) -> &'static str {
29        "Nested ``CASE`` statement in ``ELSE`` clause could be flattened."
30    }
31
32    fn long_description(&self) -> &'static str {
33        r"
34## Anti-pattern
35
36In this example, the outer `CASE`'s `ELSE` is an unnecessary, nested `CASE`.
37
38```sql
39SELECT
40  CASE
41    WHEN species = 'Cat' THEN 'Meow'
42    ELSE
43    CASE
44       WHEN species = 'Dog' THEN 'Woof'
45    END
46  END as sound
47FROM mytable
48```
49
50## Best practice
51
52Move the body of the inner `CASE` to the end of the outer one.
53
54```sql
55SELECT
56  CASE
57    WHEN species = 'Cat' THEN 'Meow'
58    WHEN species = 'Dog' THEN 'Woof'
59  END AS sound
60FROM mytable
61```
62"
63    }
64
65    fn groups(&self) -> &'static [RuleGroups] {
66        &[RuleGroups::All, RuleGroups::Structure]
67    }
68
69    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
70        let segment = FunctionalContext::new(context).segment();
71        let case1_children = segment.children_all();
72        let case1_keywords =
73            case1_children.find_first_where(|it: &ErasedSegment| it.is_keyword("CASE"));
74        let case1_first_case = case1_keywords.first().unwrap();
75        let case1_when_list = case1_children.find_first_where(|it: &ErasedSegment| {
76            matches!(
77                it.get_type(),
78                SyntaxKind::WhenClause | SyntaxKind::ElseClause
79            )
80        });
81        let case1_first_when = case1_when_list.first().unwrap();
82        let when_clause_list =
83            case1_children.find_last_where(|it| it.is_type(SyntaxKind::WhenClause));
84        let case1_last_when = when_clause_list.first();
85        let case1_else_clause =
86            case1_children.find_last_where(|it| it.is_type(SyntaxKind::ElseClause));
87        let case1_else_expressions =
88            case1_else_clause.children_where(|it| it.is_type(SyntaxKind::Expression));
89        let case2 = case1_else_expressions.children_all();
90        let case2_children = case2.children_all();
91        let case2_case_list =
92            case2_children.find_first_where(|it: &ErasedSegment| it.is_keyword("CASE"));
93        let case2_first_case = case2_case_list.first();
94        let case2_when_list = case2_children.find_first_where(|it: &ErasedSegment| {
95            matches!(
96                it.get_type(),
97                SyntaxKind::WhenClause | SyntaxKind::ElseClause
98            )
99        });
100        let case2_first_when = case2_when_list.first();
101
102        let Some(case1_last_when) = case1_last_when else {
103            return Vec::new();
104        };
105        if case1_else_expressions.len() > 1 || case2.len() > 1 || case2.is_empty() {
106            return Vec::new();
107        }
108
109        // Check if case2 actually contains a CASE expression
110        // If there's no nested CASE, we shouldn't proceed with flattening
111        let Some(case2_first_case) = case2_first_case else {
112            return Vec::new();
113        };
114
115        // Additionally check that case2 is actually a CASE expression
116        if !case2.any_match(|seg: &ErasedSegment| seg.is_type(SyntaxKind::CaseExpression)) {
117            return Vec::new();
118        }
119
120        let x1 = segment
121            .children_where(|it| it.is_code())
122            .between_exclusive(case1_first_case, case1_first_when)
123            .into_iter()
124            .map(|it| it.raw().to_smolstr());
125
126        let code2 = case2.children_where(|it| it.is_code());
127        let range2 = if let Some(stop) = case2_first_when {
128            code2.between_exclusive(case2_first_case, stop)
129        } else {
130            code2.after(case2_first_case)
131        };
132        let x2 = range2.into_iter().map(|it| it.raw().to_smolstr());
133
134        if x1.ne(x2) {
135            return Vec::new();
136        }
137
138        let case1_else_clause_seg = case1_else_clause.first().unwrap();
139
140        let case1_to_delete =
141            case1_children.between_exclusive(case1_last_when, case1_else_clause_seg);
142
143        let comments = case1_to_delete.find_last_where(|it: &ErasedSegment| it.is_comment());
144        let after_last_comment_index = comments
145            .first()
146            .and_then(|comment| case1_to_delete.iter().position(|it| it == comment))
147            .map_or(0, |n| n + 1);
148
149        let case1_comments_to_restore =
150            if let Some(stop_seg) = case1_to_delete.base.get(after_last_comment_index) {
151                case1_to_delete.before(stop_seg)
152            } else {
153                case1_to_delete.clone()
154            };
155        let after_else_comment = {
156            let children = case1_else_clause.children_all();
157            let range = if let Some(stop) = case1_else_expressions.first() {
158                children.before(stop)
159            } else {
160                children
161            };
162            range.filter(|it: &ErasedSegment| {
163                matches!(
164                    it.get_type(),
165                    SyntaxKind::Newline
166                        | SyntaxKind::InlineComment
167                        | SyntaxKind::BlockComment
168                        | SyntaxKind::Comment
169                        | SyntaxKind::Whitespace
170                )
171            })
172        };
173
174        let mut fixes = case1_to_delete
175            .into_iter()
176            .map(LintFix::delete)
177            .collect_vec();
178
179        let tab_space_size = context.config.raw["indentation"]["tab_space_size"]
180            .as_int()
181            .unwrap() as usize;
182        let indent_unit = context.config.raw["indentation"]["indent_unit"]
183            .as_string()
184            .unwrap();
185        let indent_unit = IndentUnit::from_type_and_size(indent_unit, tab_space_size);
186
187        let when_indent_str = indentation(&case1_children, case1_last_when, indent_unit);
188        let end_indent_str = indentation(&case1_children, case1_first_case, indent_unit);
189
190        let nested_clauses = case2.children_where(|it: &ErasedSegment| {
191            matches!(
192                it.get_type(),
193                SyntaxKind::WhenClause
194                    | SyntaxKind::ElseClause
195                    | SyntaxKind::Newline
196                    | SyntaxKind::InlineComment
197                    | SyntaxKind::BlockComment
198                    | SyntaxKind::Comment
199                    | SyntaxKind::Whitespace
200            )
201        });
202
203        let mut segments = case1_comments_to_restore.base;
204        segments.append(&mut rebuild_spacing(
205            context.tables,
206            &when_indent_str,
207            after_else_comment,
208        ));
209        segments.append(&mut rebuild_spacing(
210            context.tables,
211            &when_indent_str,
212            nested_clauses,
213        ));
214
215        fixes.push(LintFix::create_after(
216            case1_last_when.clone(),
217            segments,
218            None,
219        ));
220        fixes.push(LintFix::delete(case1_else_clause_seg.clone()));
221        fixes.append(&mut nested_end_trailing_comment(
222            context.tables,
223            case1_children,
224            case1_else_clause_seg,
225            &end_indent_str,
226        ));
227
228        vec![LintResult::new(case2.first().cloned(), fixes, None, None)]
229    }
230
231    fn is_fix_compatible(&self) -> bool {
232        true
233    }
234
235    fn crawl_behaviour(&self) -> Crawler {
236        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::CaseExpression]) }).into()
237    }
238}
239
240fn indentation(
241    parent_segments: &Segments,
242    segment: &ErasedSegment,
243    indent_unit: IndentUnit,
244) -> String {
245    let leading_whitespace = parent_segments
246        .before(segment)
247        .reversed()
248        .find_first_where(|it: &ErasedSegment| it.is_type(SyntaxKind::Whitespace));
249    let seg_indent = parent_segments
250        .before(segment)
251        .find_last_where(|it| it.is_type(SyntaxKind::Indent));
252    let mut indent_level = 1;
253    if let Some(segment_indent) = seg_indent
254        .last()
255        .filter(|segment_indent| segment_indent.is_indent())
256    {
257        indent_level = segment_indent.indent_val() as usize + 1;
258    }
259
260    if let Some(whitespace_seg) = leading_whitespace.first() {
261        if !leading_whitespace.is_empty() && whitespace_seg.raw().len() > 1 {
262            leading_whitespace
263                .iter()
264                .map(|seg| seg.raw().to_string())
265                .collect::<String>()
266        } else {
267            construct_single_indent(indent_unit).repeat(indent_level)
268        }
269    } else {
270        construct_single_indent(indent_unit).repeat(indent_level)
271    }
272}
273
274fn rebuild_spacing(
275    tables: &Tables,
276    indent_str: &str,
277    nested_clauses: Segments,
278) -> Vec<ErasedSegment> {
279    let mut buff = Vec::new();
280
281    let mut prior_newline = nested_clauses
282        .find_last_where(|it: &ErasedSegment| !it.is_whitespace())
283        .any_match(|it: &ErasedSegment| it.is_comment());
284    let mut prior_whitespace = String::new();
285
286    for seg in nested_clauses {
287        if matches!(
288            seg.get_type(),
289            SyntaxKind::WhenClause | SyntaxKind::ElseClause
290        ) || (prior_newline && seg.is_comment())
291        {
292            buff.push(SegmentBuilder::newline(tables.next_id(), "\n"));
293            buff.push(SegmentBuilder::whitespace(tables.next_id(), indent_str));
294            buff.push(seg.clone());
295            prior_newline = false;
296            prior_whitespace.clear();
297        } else if seg.is_type(SyntaxKind::Newline) {
298            prior_newline = true;
299            prior_whitespace.clear();
300        } else if !prior_newline && seg.is_comment() {
301            buff.push(SegmentBuilder::whitespace(
302                tables.next_id(),
303                &prior_whitespace,
304            ));
305            buff.push(seg.clone());
306            prior_newline = false;
307            prior_whitespace.clear();
308        } else if seg.is_whitespace() {
309            prior_whitespace = seg.raw().to_string();
310        }
311    }
312
313    buff
314}
315
316fn nested_end_trailing_comment(
317    tables: &Tables,
318    case1_children: Segments,
319    case1_else_clause_seg: &ErasedSegment,
320    end_indent_str: &str,
321) -> Vec<LintFix> {
322    // Prepend newline spacing to comments on the final nested `END` line.
323    let trailing_end = case1_children
324        .after(case1_else_clause_seg)
325        .take_while(|seg: &ErasedSegment| !seg.is_type(SyntaxKind::Newline));
326
327    let mut fixes = trailing_end
328        .take_while(|seg: &ErasedSegment| !seg.is_comment())
329        .filter(|seg: &ErasedSegment| seg.is_whitespace())
330        .into_iter()
331        .map(LintFix::delete)
332        .collect_vec();
333
334    if let Some(first_comment) = trailing_end
335        .find_first_where(|seg: &ErasedSegment| seg.is_comment())
336        .first()
337    {
338        let segments = vec![
339            SegmentBuilder::newline(tables.next_id(), "\n"),
340            SegmentBuilder::whitespace(tables.next_id(), end_indent_str),
341        ];
342        fixes.push(LintFix::create_before(first_comment.clone(), segments));
343    }
344
345    fixes
346}