sqruff_lib/rules/convention/
cv11.rs

1use ahash::AHashMap;
2use itertools::chain;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder, Tables};
6use sqruff_lib_core::utils::functional::segments::Segments;
7use strum_macros::{AsRefStr, EnumString};
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14
15#[derive(Debug, Copy, Clone, AsRefStr, EnumString, PartialEq, Default)]
16#[strum(serialize_all = "snake_case")]
17enum TypeCastingStyle {
18    #[default]
19    Consistent,
20    Cast,
21    Convert,
22    Shorthand,
23    None,
24}
25
/// Marker stored in the rule context memory once a `CONVERT` call was left
/// untouched because its argument list could not be rewritten.
#[derive(Clone, Copy)]
struct PreviousSkipped;
28
29fn get_children(segments: Segments) -> Segments {
30    segments.children_where(|it: &ErasedSegment| {
31        !it.is_meta()
32            && !matches!(
33                it.get_type(),
34                SyntaxKind::StartBracket
35                    | SyntaxKind::EndBracket
36                    | SyntaxKind::Whitespace
37                    | SyntaxKind::Newline
38                    | SyntaxKind::CastingOperator
39                    | SyntaxKind::Comma
40                    | SyntaxKind::Keyword
41            )
42    })
43}
44
45fn shorthand_fix_list(
46    tables: &Tables,
47    root_segment: ErasedSegment,
48    shorthand_arg_1: ErasedSegment,
49    shorthand_arg_2: ErasedSegment,
50) -> Vec<LintFix> {
51    let mut edits = if shorthand_arg_1.get_raw_segments().len() > 1 {
52        vec![
53            SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
54            shorthand_arg_1,
55            SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
56        ]
57    } else {
58        vec![shorthand_arg_1]
59    };
60
61    edits.extend([
62        SegmentBuilder::token(tables.next_id(), "::", SyntaxKind::CastingOperator).finish(),
63        shorthand_arg_2,
64    ]);
65
66    vec![LintFix::replace(root_segment, edits, None)]
67}
68
69#[derive(Clone, Debug, Default)]
70pub struct RuleCV11 {
71    preferred_type_casting_style: TypeCastingStyle,
72}
73
74impl Rule for RuleCV11 {
75    fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
76        Ok(RuleCV11 {
77            preferred_type_casting_style: config["preferred_type_casting_style"]
78                .as_string()
79                .unwrap()
80                .parse()
81                .unwrap(),
82        }
83        .erased())
84    }
85
86    fn name(&self) -> &'static str {
87        "convention.casting_style"
88    }
89
90    fn description(&self) -> &'static str {
91        "Enforce consistent type casting style."
92    }
93
94    fn long_description(&self) -> &'static str {
95        r"
96**Anti-pattern**
97
98Using a mixture of `CONVERT`, `::`, and `CAST` when `preferred_type_casting_style` config is set to `consistent` (default).
99
100```sql
101SELECT
102    CONVERT(int, 1) AS bar,
103    100::int::text,
104    CAST(10 AS text) AS coo
105FROM foo;
106```
107
108**Best Practice**
109
110Use a consistent type casting style.
111
112```sql
113SELECT
114    CAST(1 AS int) AS bar,
115    CAST(CAST(100 AS int) AS text),
116    CAST(10 AS text) AS coo
117FROM foo;
118```
119"
120    }
121
122    fn groups(&self) -> &'static [RuleGroups] {
123        &[RuleGroups::All, RuleGroups::Convention]
124    }
125
126    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
127        // If we're in a templated section, don't consider the current location.
128        // (i.e. if a cast happens in a macro, the end user writing the current
129        // query may not know that or have control over it, so we should just
130        // skip it).
131        if let Some(pos_marker) = &context.segment.get_position_marker()
132            && !pos_marker.is_literal()
133        {
134            return Vec::new();
135        }
136
137        let current_type_casting_style = if context.segment.is_type(SyntaxKind::Function) {
138            let Some(function_name) = context
139                .segment
140                .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
141            else {
142                return Vec::new();
143            };
144            if function_name.raw().eq_ignore_ascii_case("CAST") {
145                TypeCastingStyle::Cast
146            } else if function_name.raw().eq_ignore_ascii_case("CONVERT") {
147                TypeCastingStyle::Convert
148            } else {
149                TypeCastingStyle::None
150            }
151        } else if context.segment.is_type(SyntaxKind::CastExpression) {
152            TypeCastingStyle::Shorthand
153        } else {
154            TypeCastingStyle::None
155        };
156
157        let functional_context = FunctionalContext::new(context);
158        match self.preferred_type_casting_style {
159            TypeCastingStyle::Consistent => {
160                let Some(prior_type_casting_style) = context.try_get::<TypeCastingStyle>() else {
161                    context.set(current_type_casting_style);
162                    return Vec::new();
163                };
164                let previous_skipped = context.try_get::<PreviousSkipped>();
165
166                let mut fixes = Vec::new();
167                match prior_type_casting_style {
168                    TypeCastingStyle::Cast => match current_type_casting_style {
169                        TypeCastingStyle::Convert => {
170                            let bracketed = functional_context
171                                .segment()
172                                .children_where(|it: &ErasedSegment| {
173                                    it.is_type(SyntaxKind::FunctionContents)
174                                })
175                                .children_where(|it: &ErasedSegment| {
176                                    it.is_type(SyntaxKind::Bracketed)
177                                });
178                            let convert_content = get_children(bracketed);
179                            if convert_content.len() > 2 {
180                                if previous_skipped.is_none() {
181                                    context.set(PreviousSkipped);
182                                }
183                                return Vec::new();
184                            }
185
186                            fixes = cast_fix_list(
187                                context.tables,
188                                context.segment.clone(),
189                                &[convert_content[1].clone()],
190                                convert_content[0].clone(),
191                                None,
192                            );
193                        }
194                        TypeCastingStyle::Shorthand => {
195                            let expression_datatype_segment =
196                                get_children(functional_context.segment());
197
198                            fixes = cast_fix_list(
199                                context.tables,
200                                context.segment.clone(),
201                                &[expression_datatype_segment[0].clone()],
202                                expression_datatype_segment[1].clone(),
203                                Some(Segments::from_vec(
204                                    expression_datatype_segment.base[2..].to_vec(),
205                                    None,
206                                )),
207                            )
208                        }
209                        _ => {}
210                    },
211                    TypeCastingStyle::Convert => match current_type_casting_style {
212                        TypeCastingStyle::Cast => {
213                            let bracketed = functional_context
214                                .segment()
215                                .children_where(|it: &ErasedSegment| {
216                                    it.is_type(SyntaxKind::FunctionContents)
217                                })
218                                .children_where(|it: &ErasedSegment| {
219                                    it.is_type(SyntaxKind::Bracketed)
220                                });
221                            let cast_content = get_children(bracketed);
222
223                            if cast_content.len() > 2 {
224                                return Vec::new();
225                            }
226
227                            fixes = convert_fix_list(
228                                context.tables,
229                                context.segment.clone(),
230                                cast_content[1].clone(),
231                                cast_content[0].clone(),
232                                None,
233                            );
234                        }
235                        TypeCastingStyle::Shorthand => {
236                            let expression_datatype_segment =
237                                get_children(functional_context.segment());
238
239                            fixes = convert_fix_list(
240                                context.tables,
241                                context.segment.clone(),
242                                expression_datatype_segment[1].clone(),
243                                expression_datatype_segment[0].clone(),
244                                Some(Segments::from_vec(
245                                    expression_datatype_segment.base[2..].to_vec(),
246                                    None,
247                                )),
248                            );
249                        }
250                        _ => (),
251                    },
252                    TypeCastingStyle::Shorthand => {
253                        if current_type_casting_style == TypeCastingStyle::Cast {
254                            // Get the content of CAST
255                            let bracketed = functional_context
256                                .segment()
257                                .children_where(|it: &ErasedSegment| {
258                                    it.is_type(SyntaxKind::FunctionContents)
259                                })
260                                .children_where(|it: &ErasedSegment| {
261                                    it.is_type(SyntaxKind::Bracketed)
262                                });
263                            let cast_content = get_children(bracketed);
264                            if cast_content.len() > 2 {
265                                return Vec::new();
266                            }
267
268                            fixes = shorthand_fix_list(
269                                context.tables,
270                                context.segment.clone(),
271                                cast_content[0].clone(),
272                                cast_content[1].clone(),
273                            );
274                        } else if current_type_casting_style == TypeCastingStyle::Convert {
275                            let bracketed = functional_context
276                                .segment()
277                                .children_where(|it: &ErasedSegment| {
278                                    it.is_type(SyntaxKind::FunctionContents)
279                                })
280                                .children_where(|it: &ErasedSegment| {
281                                    it.is_type(SyntaxKind::Bracketed)
282                                });
283                            let convert_content = get_children(bracketed);
284                            if convert_content.len() > 2 {
285                                return Vec::new();
286                            }
287
288                            fixes = shorthand_fix_list(
289                                context.tables,
290                                context.segment.clone(),
291                                convert_content[1].clone(),
292                                convert_content[0].clone(),
293                            );
294                        }
295                    }
296                    _ => {}
297                }
298
299                if prior_type_casting_style != current_type_casting_style {
300                    return vec![LintResult::new(
301                        context.segment.clone().into(),
302                        fixes,
303                        "Inconsistent type casting styles found.".to_owned().into(),
304                        None,
305                    )];
306                }
307            }
308            _ if current_type_casting_style != self.preferred_type_casting_style => {
309                let mut convert_content = None;
310                let mut cast_content = None;
311                let mut fixes = Vec::new();
312
313                match self.preferred_type_casting_style {
314                    TypeCastingStyle::Cast => match current_type_casting_style {
315                        TypeCastingStyle::Convert => {
316                            let bracketed = functional_context
317                                .segment()
318                                .children_where(|it: &ErasedSegment| {
319                                    it.is_type(SyntaxKind::FunctionContents)
320                                })
321                                .children_where(|it: &ErasedSegment| {
322                                    it.is_type(SyntaxKind::Bracketed)
323                                });
324                            let segments = get_children(bracketed);
325                            fixes = cast_fix_list(
326                                context.tables,
327                                context.segment.clone(),
328                                &[segments[1].clone()],
329                                segments[0].clone(),
330                                None,
331                            );
332                            convert_content = Some(segments);
333                        }
334                        TypeCastingStyle::Shorthand => {
335                            let expression_datatype_segment =
336                                get_children(functional_context.segment());
337                            let data_type_idx = expression_datatype_segment
338                                .iter()
339                                .position(|seg| seg.is_type(SyntaxKind::DataType))
340                                .unwrap();
341
342                            fixes = cast_fix_list(
343                                context.tables,
344                                context.segment.clone(),
345                                &expression_datatype_segment[..data_type_idx],
346                                expression_datatype_segment[data_type_idx].clone(),
347                                Some(Segments::from_vec(
348                                    expression_datatype_segment.base[data_type_idx + 1..].to_vec(),
349                                    None,
350                                )),
351                            );
352                        }
353                        _ => {}
354                    },
355                    TypeCastingStyle::Convert => match current_type_casting_style {
356                        TypeCastingStyle::Cast => {
357                            let bracketed = functional_context
358                                .segment()
359                                .children_where(|it: &ErasedSegment| {
360                                    it.is_type(SyntaxKind::FunctionContents)
361                                })
362                                .children_where(|it: &ErasedSegment| {
363                                    it.is_type(SyntaxKind::Bracketed)
364                                });
365                            let cast_content = get_children(bracketed);
366
367                            fixes = convert_fix_list(
368                                context.tables,
369                                context.segment.clone(),
370                                cast_content[1].clone(),
371                                cast_content[0].clone(),
372                                None,
373                            );
374                        }
375                        TypeCastingStyle::Shorthand => {
376                            let cast_content = get_children(functional_context.segment());
377
378                            fixes = convert_fix_list(
379                                context.tables,
380                                context.segment.clone(),
381                                cast_content[1].clone(),
382                                cast_content[0].clone(),
383                                Some(Segments::from_vec(cast_content.base[2..].to_vec(), None)),
384                            )
385                        }
386                        _ => {}
387                    },
388                    TypeCastingStyle::Shorthand => match current_type_casting_style {
389                        TypeCastingStyle::Cast => {
390                            let bracketed = functional_context
391                                .segment()
392                                .children_where(|it: &ErasedSegment| {
393                                    it.is_type(SyntaxKind::FunctionContents)
394                                })
395                                .children_where(|it: &ErasedSegment| {
396                                    it.is_type(SyntaxKind::Bracketed)
397                                });
398                            let segments = get_children(bracketed);
399
400                            fixes = shorthand_fix_list(
401                                context.tables,
402                                context.segment.clone(),
403                                segments[0].clone(),
404                                segments[1].clone(),
405                            );
406                            cast_content = Some(segments);
407                        }
408                        TypeCastingStyle::Convert => {
409                            let bracketed = functional_context
410                                .segment()
411                                .children_where(|it: &ErasedSegment| {
412                                    it.is_type(SyntaxKind::FunctionContents)
413                                })
414                                .children_where(|it: &ErasedSegment| {
415                                    it.is_type(SyntaxKind::Bracketed)
416                                });
417                            let segments = get_children(bracketed);
418
419                            fixes = shorthand_fix_list(
420                                context.tables,
421                                context.segment.clone(),
422                                segments[1].clone(),
423                                segments[0].clone(),
424                            );
425
426                            convert_content = Some(segments);
427                        }
428                        _ => {}
429                    },
430                    _ => {}
431                }
432
433                if convert_content
434                    .filter(|convert_content| convert_content.len() > 2)
435                    .is_some()
436                {
437                    fixes.clear();
438                }
439
440                if cast_content
441                    .filter(|cast_content| cast_content.len() > 2)
442                    .is_some()
443                {
444                    fixes.clear();
445                }
446
447                return vec![LintResult::new(
448                    context.segment.clone().into(),
449                    fixes,
450                    "Used type casting style is different from the preferred type casting style."
451                        .to_owned()
452                        .into(),
453                    None,
454                )];
455            }
456
457            _ => {}
458        }
459
460        Vec::new()
461    }
462
463    fn is_fix_compatible(&self) -> bool {
464        true
465    }
466
467    fn crawl_behaviour(&self) -> Crawler {
468        SegmentSeekerCrawler::new(
469            const { SyntaxSet::new(&[SyntaxKind::Function, SyntaxKind::CastExpression]) },
470        )
471        .into()
472    }
473}
474
475fn convert_fix_list(
476    tables: &Tables,
477    root: ErasedSegment,
478    convert_arg_1: ErasedSegment,
479    convert_arg_2: ErasedSegment,
480    later_types: Option<Segments>,
481) -> Vec<LintFix> {
482    use sqruff_lib_core::parser::segments::ErasedSegment;
483
484    let mut edits: Vec<ErasedSegment> = vec![
485        SegmentBuilder::token(
486            tables.next_id(),
487            "convert",
488            SyntaxKind::FunctionNameIdentifier,
489        )
490        .finish(),
491        SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
492        convert_arg_1,
493        SegmentBuilder::token(tables.next_id(), ",", SyntaxKind::Comma).finish(),
494        SegmentBuilder::whitespace(tables.next_id(), " "),
495        convert_arg_2,
496        SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
497    ];
498
499    if let Some(later_types) = later_types {
500        let pre_edits: Vec<ErasedSegment> = vec![
501            SegmentBuilder::token(
502                tables.next_id(),
503                "convert",
504                SyntaxKind::FunctionNameIdentifier,
505            )
506            .finish(),
507            SegmentBuilder::symbol(tables.next_id(), "("),
508        ];
509
510        let in_edits: Vec<ErasedSegment> = vec![
511            SegmentBuilder::symbol(tables.next_id(), ","),
512            SegmentBuilder::whitespace(tables.next_id(), " "),
513        ];
514
515        let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
516
517        for _type in later_types.base {
518            edits = chain(
519                chain(pre_edits.clone(), vec![_type]),
520                chain(in_edits.clone(), chain(edits, post_edits.clone())),
521            )
522            .collect();
523        }
524    }
525
526    vec![LintFix::replace(root, edits, None)]
527}
528
529fn cast_fix_list(
530    tables: &Tables,
531    root: ErasedSegment,
532    cast_arg_1: &[ErasedSegment],
533    cast_arg_2: ErasedSegment,
534    later_types: Option<Segments>,
535) -> Vec<LintFix> {
536    let mut edits = vec![
537        SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
538            .finish(),
539        SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
540    ];
541    edits.extend_from_slice(cast_arg_1);
542    edits.extend([
543        SegmentBuilder::whitespace(tables.next_id(), " "),
544        SegmentBuilder::keyword(tables.next_id(), "as"),
545        SegmentBuilder::whitespace(tables.next_id(), " "),
546        cast_arg_2,
547        SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
548    ]);
549
550    if let Some(later_types) = later_types {
551        let pre_edits: Vec<ErasedSegment> = vec![
552            SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
553                .finish(),
554            SegmentBuilder::symbol(tables.next_id(), "("),
555        ];
556
557        let in_edits: Vec<ErasedSegment> = vec![
558            SegmentBuilder::whitespace(tables.next_id(), " "),
559            SegmentBuilder::keyword(tables.next_id(), "as"),
560            SegmentBuilder::whitespace(tables.next_id(), " "),
561        ];
562
563        let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
564
565        for _type in later_types.base {
566            let mut xs = Vec::new();
567            xs.extend(pre_edits.clone());
568            xs.extend(edits);
569            xs.extend(in_edits.clone());
570            xs.push(_type);
571            xs.extend(post_edits.clone());
572            edits = xs;
573        }
574    }
575
576    vec![LintFix::replace(root, edits, None)]
577}