sqruff_lib/rules/convention/
cv11.rs

1use ahash::AHashMap;
2use itertools::chain;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder, Tables};
6use sqruff_lib_core::utils::functional::segments::Segments;
7use strum_macros::{AsRefStr, EnumString};
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14
15#[derive(Debug, Copy, Clone, AsRefStr, EnumString, PartialEq, Default)]
16#[strum(serialize_all = "snake_case")]
17enum TypeCastingStyle {
18    #[default]
19    Consistent,
20    Cast,
21    Convert,
22    Shorthand,
23    None,
24}
25
/// Marker stored in the rule context (consistent mode) once a `CONVERT` call
/// with more than two significant arguments has been skipped.
#[derive(Clone, Copy)]
struct PreviousSkipped;
28
29fn get_children(segments: Segments) -> Segments {
30    segments.children_where(|it: &ErasedSegment| {
31        !it.is_meta()
32            && !matches!(
33                it.get_type(),
34                SyntaxKind::StartBracket
35                    | SyntaxKind::EndBracket
36                    | SyntaxKind::Whitespace
37                    | SyntaxKind::Newline
38                    | SyntaxKind::CastingOperator
39                    | SyntaxKind::Comma
40                    | SyntaxKind::Keyword
41            )
42    })
43}
44
45fn shorthand_fix_list(
46    tables: &Tables,
47    root_segment: ErasedSegment,
48    shorthand_arg_1: ErasedSegment,
49    shorthand_arg_2: ErasedSegment,
50) -> Vec<LintFix> {
51    let mut edits = if shorthand_arg_1.get_raw_segments().len() > 1 {
52        vec![
53            SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
54            shorthand_arg_1,
55            SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
56        ]
57    } else {
58        vec![shorthand_arg_1]
59    };
60
61    edits.extend([
62        SegmentBuilder::token(tables.next_id(), "::", SyntaxKind::CastingOperator).finish(),
63        shorthand_arg_2,
64    ]);
65
66    vec![LintFix::replace(root_segment, edits, None)]
67}
68
69#[derive(Clone, Debug, Default)]
70pub struct RuleCV11 {
71    preferred_type_casting_style: TypeCastingStyle,
72}
73
74impl Rule for RuleCV11 {
75    fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
76        Ok(RuleCV11 {
77            preferred_type_casting_style: config["preferred_type_casting_style"]
78                .as_string()
79                .unwrap()
80                .parse()
81                .unwrap(),
82        }
83        .erased())
84    }
85
86    fn name(&self) -> &'static str {
87        "convention.casting_style"
88    }
89
90    fn description(&self) -> &'static str {
91        "Enforce consistent type casting style."
92    }
93
94    fn long_description(&self) -> &'static str {
95        r"
96**Anti-pattern**
97
98Using a mixture of `CONVERT`, `::`, and `CAST` when `preferred_type_casting_style` config is set to `consistent` (default).
99
100```sql
101SELECT
102    CONVERT(int, 1) AS bar,
103    100::int::text,
104    CAST(10 AS text) AS coo
105FROM foo;
106```
107
108**Best Practice**
109
110Use a consistent type casting style.
111
112```sql
113SELECT
114    CAST(1 AS int) AS bar,
115    CAST(CAST(100 AS int) AS text),
116    CAST(10 AS text) AS coo
117FROM foo;
118```
119"
120    }
121
122    fn groups(&self) -> &'static [RuleGroups] {
123        &[RuleGroups::All, RuleGroups::Convention]
124    }
125
126    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
127        // If we're in a templated section, don't consider the current location.
128        // (i.e. if a cast happens in a macro, the end user writing the current
129        // query may not know that or have control over it, so we should just
130        // skip it).
131        if let Some(pos_marker) = &context.segment.get_position_marker()
132            && !pos_marker.is_literal()
133        {
134            return Vec::new();
135        }
136
137        let current_type_casting_style = if context.segment.is_type(SyntaxKind::Function) {
138            let Some(function_name) = context
139                .segment
140                .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
141            else {
142                return Vec::new();
143            };
144            if function_name.raw().eq_ignore_ascii_case("CAST") {
145                TypeCastingStyle::Cast
146            } else if function_name.raw().eq_ignore_ascii_case("CONVERT") {
147                TypeCastingStyle::Convert
148            } else {
149                TypeCastingStyle::None
150            }
151        } else if context.segment.is_type(SyntaxKind::CastExpression) {
152            TypeCastingStyle::Shorthand
153        } else {
154            TypeCastingStyle::None
155        };
156
157        let functional_context = FunctionalContext::new(context);
158        match self.preferred_type_casting_style {
159            TypeCastingStyle::Consistent => {
160                // If current is None, it's not a cast operation (e.g., STRING_AGG, or
161                // other non-CAST/CONVERT functions), so skip it entirely.
162                if current_type_casting_style == TypeCastingStyle::None {
163                    return Vec::new();
164                }
165
166                let Some(prior_type_casting_style) = context.try_get::<TypeCastingStyle>() else {
167                    context.set(current_type_casting_style);
168                    return Vec::new();
169                };
170                let previous_skipped = context.try_get::<PreviousSkipped>();
171
172                let mut fixes = Vec::new();
173                match prior_type_casting_style {
174                    TypeCastingStyle::Cast => match current_type_casting_style {
175                        TypeCastingStyle::Convert => {
176                            let bracketed = functional_context
177                                .segment()
178                                .children_where(|it: &ErasedSegment| {
179                                    it.is_type(SyntaxKind::FunctionContents)
180                                })
181                                .children_where(|it: &ErasedSegment| {
182                                    it.is_type(SyntaxKind::Bracketed)
183                                });
184                            let convert_content = get_children(bracketed);
185                            if convert_content.len() > 2 {
186                                if previous_skipped.is_none() {
187                                    context.set(PreviousSkipped);
188                                }
189                                return Vec::new();
190                            }
191
192                            fixes = cast_fix_list(
193                                context.tables,
194                                context.segment.clone(),
195                                &[convert_content[1].clone()],
196                                convert_content[0].clone(),
197                                None,
198                            );
199                        }
200                        TypeCastingStyle::Shorthand => {
201                            let expression_datatype_segment =
202                                get_children(functional_context.segment());
203
204                            fixes = cast_fix_list(
205                                context.tables,
206                                context.segment.clone(),
207                                &[expression_datatype_segment[0].clone()],
208                                expression_datatype_segment[1].clone(),
209                                Some(Segments::from_vec(
210                                    expression_datatype_segment.base[2..].to_vec(),
211                                    None,
212                                )),
213                            )
214                        }
215                        _ => {}
216                    },
217                    TypeCastingStyle::Convert => match current_type_casting_style {
218                        TypeCastingStyle::Cast => {
219                            let bracketed = functional_context
220                                .segment()
221                                .children_where(|it: &ErasedSegment| {
222                                    it.is_type(SyntaxKind::FunctionContents)
223                                })
224                                .children_where(|it: &ErasedSegment| {
225                                    it.is_type(SyntaxKind::Bracketed)
226                                });
227                            let cast_content = get_children(bracketed);
228
229                            if cast_content.len() > 2 {
230                                return Vec::new();
231                            }
232
233                            fixes = convert_fix_list(
234                                context.tables,
235                                context.segment.clone(),
236                                cast_content[1].clone(),
237                                cast_content[0].clone(),
238                                None,
239                            );
240                        }
241                        TypeCastingStyle::Shorthand => {
242                            let expression_datatype_segment =
243                                get_children(functional_context.segment());
244
245                            fixes = convert_fix_list(
246                                context.tables,
247                                context.segment.clone(),
248                                expression_datatype_segment[1].clone(),
249                                expression_datatype_segment[0].clone(),
250                                Some(Segments::from_vec(
251                                    expression_datatype_segment.base[2..].to_vec(),
252                                    None,
253                                )),
254                            );
255                        }
256                        _ => (),
257                    },
258                    TypeCastingStyle::Shorthand => {
259                        if current_type_casting_style == TypeCastingStyle::Cast {
260                            // Get the content of CAST
261                            let bracketed = functional_context
262                                .segment()
263                                .children_where(|it: &ErasedSegment| {
264                                    it.is_type(SyntaxKind::FunctionContents)
265                                })
266                                .children_where(|it: &ErasedSegment| {
267                                    it.is_type(SyntaxKind::Bracketed)
268                                });
269                            let cast_content = get_children(bracketed);
270                            if cast_content.len() > 2 {
271                                return Vec::new();
272                            }
273
274                            fixes = shorthand_fix_list(
275                                context.tables,
276                                context.segment.clone(),
277                                cast_content[0].clone(),
278                                cast_content[1].clone(),
279                            );
280                        } else if current_type_casting_style == TypeCastingStyle::Convert {
281                            let bracketed = functional_context
282                                .segment()
283                                .children_where(|it: &ErasedSegment| {
284                                    it.is_type(SyntaxKind::FunctionContents)
285                                })
286                                .children_where(|it: &ErasedSegment| {
287                                    it.is_type(SyntaxKind::Bracketed)
288                                });
289                            let convert_content = get_children(bracketed);
290                            if convert_content.len() > 2 {
291                                return Vec::new();
292                            }
293
294                            fixes = shorthand_fix_list(
295                                context.tables,
296                                context.segment.clone(),
297                                convert_content[1].clone(),
298                                convert_content[0].clone(),
299                            );
300                        }
301                    }
302                    _ => {}
303                }
304
305                if prior_type_casting_style != current_type_casting_style {
306                    return vec![LintResult::new(
307                        context.segment.clone().into(),
308                        fixes,
309                        "Inconsistent type casting styles found.".to_owned().into(),
310                        None,
311                    )];
312                }
313            }
314            _ if current_type_casting_style != self.preferred_type_casting_style => {
315                let mut convert_content = None;
316                let mut cast_content = None;
317                let mut fixes = Vec::new();
318
319                match self.preferred_type_casting_style {
320                    TypeCastingStyle::Cast => match current_type_casting_style {
321                        TypeCastingStyle::Convert => {
322                            let bracketed = functional_context
323                                .segment()
324                                .children_where(|it: &ErasedSegment| {
325                                    it.is_type(SyntaxKind::FunctionContents)
326                                })
327                                .children_where(|it: &ErasedSegment| {
328                                    it.is_type(SyntaxKind::Bracketed)
329                                });
330                            let segments = get_children(bracketed);
331                            fixes = cast_fix_list(
332                                context.tables,
333                                context.segment.clone(),
334                                &[segments[1].clone()],
335                                segments[0].clone(),
336                                None,
337                            );
338                            convert_content = Some(segments);
339                        }
340                        TypeCastingStyle::Shorthand => {
341                            let expression_datatype_segment =
342                                get_children(functional_context.segment());
343                            let data_type_idx = expression_datatype_segment
344                                .iter()
345                                .position(|seg| seg.is_type(SyntaxKind::DataType))
346                                .unwrap();
347
348                            fixes = cast_fix_list(
349                                context.tables,
350                                context.segment.clone(),
351                                &expression_datatype_segment[..data_type_idx],
352                                expression_datatype_segment[data_type_idx].clone(),
353                                Some(Segments::from_vec(
354                                    expression_datatype_segment.base[data_type_idx + 1..].to_vec(),
355                                    None,
356                                )),
357                            );
358                        }
359                        _ => {}
360                    },
361                    TypeCastingStyle::Convert => match current_type_casting_style {
362                        TypeCastingStyle::Cast => {
363                            let bracketed = functional_context
364                                .segment()
365                                .children_where(|it: &ErasedSegment| {
366                                    it.is_type(SyntaxKind::FunctionContents)
367                                })
368                                .children_where(|it: &ErasedSegment| {
369                                    it.is_type(SyntaxKind::Bracketed)
370                                });
371                            let cast_content = get_children(bracketed);
372
373                            fixes = convert_fix_list(
374                                context.tables,
375                                context.segment.clone(),
376                                cast_content[1].clone(),
377                                cast_content[0].clone(),
378                                None,
379                            );
380                        }
381                        TypeCastingStyle::Shorthand => {
382                            let cast_content = get_children(functional_context.segment());
383
384                            fixes = convert_fix_list(
385                                context.tables,
386                                context.segment.clone(),
387                                cast_content[1].clone(),
388                                cast_content[0].clone(),
389                                Some(Segments::from_vec(cast_content.base[2..].to_vec(), None)),
390                            )
391                        }
392                        _ => {}
393                    },
394                    TypeCastingStyle::Shorthand => match current_type_casting_style {
395                        TypeCastingStyle::Cast => {
396                            let bracketed = functional_context
397                                .segment()
398                                .children_where(|it: &ErasedSegment| {
399                                    it.is_type(SyntaxKind::FunctionContents)
400                                })
401                                .children_where(|it: &ErasedSegment| {
402                                    it.is_type(SyntaxKind::Bracketed)
403                                });
404                            let segments = get_children(bracketed);
405
406                            fixes = shorthand_fix_list(
407                                context.tables,
408                                context.segment.clone(),
409                                segments[0].clone(),
410                                segments[1].clone(),
411                            );
412                            cast_content = Some(segments);
413                        }
414                        TypeCastingStyle::Convert => {
415                            let bracketed = functional_context
416                                .segment()
417                                .children_where(|it: &ErasedSegment| {
418                                    it.is_type(SyntaxKind::FunctionContents)
419                                })
420                                .children_where(|it: &ErasedSegment| {
421                                    it.is_type(SyntaxKind::Bracketed)
422                                });
423                            let segments = get_children(bracketed);
424
425                            fixes = shorthand_fix_list(
426                                context.tables,
427                                context.segment.clone(),
428                                segments[1].clone(),
429                                segments[0].clone(),
430                            );
431
432                            convert_content = Some(segments);
433                        }
434                        _ => {}
435                    },
436                    _ => {}
437                }
438
439                if convert_content
440                    .filter(|convert_content| convert_content.len() > 2)
441                    .is_some()
442                {
443                    fixes.clear();
444                }
445
446                if cast_content
447                    .filter(|cast_content| cast_content.len() > 2)
448                    .is_some()
449                {
450                    fixes.clear();
451                }
452
453                return vec![LintResult::new(
454                    context.segment.clone().into(),
455                    fixes,
456                    "Used type casting style is different from the preferred type casting style."
457                        .to_owned()
458                        .into(),
459                    None,
460                )];
461            }
462
463            _ => {}
464        }
465
466        Vec::new()
467    }
468
469    fn is_fix_compatible(&self) -> bool {
470        true
471    }
472
473    fn crawl_behaviour(&self) -> Crawler {
474        SegmentSeekerCrawler::new(
475            const { SyntaxSet::new(&[SyntaxKind::Function, SyntaxKind::CastExpression]) },
476        )
477        .into()
478    }
479}
480
481fn convert_fix_list(
482    tables: &Tables,
483    root: ErasedSegment,
484    convert_arg_1: ErasedSegment,
485    convert_arg_2: ErasedSegment,
486    later_types: Option<Segments>,
487) -> Vec<LintFix> {
488    use sqruff_lib_core::parser::segments::ErasedSegment;
489
490    let mut edits: Vec<ErasedSegment> = vec![
491        SegmentBuilder::token(
492            tables.next_id(),
493            "convert",
494            SyntaxKind::FunctionNameIdentifier,
495        )
496        .finish(),
497        SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
498        convert_arg_1,
499        SegmentBuilder::token(tables.next_id(), ",", SyntaxKind::Comma).finish(),
500        SegmentBuilder::whitespace(tables.next_id(), " "),
501        convert_arg_2,
502        SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
503    ];
504
505    if let Some(later_types) = later_types {
506        let pre_edits: Vec<ErasedSegment> = vec![
507            SegmentBuilder::token(
508                tables.next_id(),
509                "convert",
510                SyntaxKind::FunctionNameIdentifier,
511            )
512            .finish(),
513            SegmentBuilder::symbol(tables.next_id(), "("),
514        ];
515
516        let in_edits: Vec<ErasedSegment> = vec![
517            SegmentBuilder::symbol(tables.next_id(), ","),
518            SegmentBuilder::whitespace(tables.next_id(), " "),
519        ];
520
521        let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
522
523        for _type in later_types.base {
524            edits = chain(
525                chain(pre_edits.clone(), vec![_type]),
526                chain(in_edits.clone(), chain(edits, post_edits.clone())),
527            )
528            .collect();
529        }
530    }
531
532    vec![LintFix::replace(root, edits, None)]
533}
534
535fn cast_fix_list(
536    tables: &Tables,
537    root: ErasedSegment,
538    cast_arg_1: &[ErasedSegment],
539    cast_arg_2: ErasedSegment,
540    later_types: Option<Segments>,
541) -> Vec<LintFix> {
542    let mut edits = vec![
543        SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
544            .finish(),
545        SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
546    ];
547    edits.extend_from_slice(cast_arg_1);
548    edits.extend([
549        SegmentBuilder::whitespace(tables.next_id(), " "),
550        SegmentBuilder::keyword(tables.next_id(), "as"),
551        SegmentBuilder::whitespace(tables.next_id(), " "),
552        cast_arg_2,
553        SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
554    ]);
555
556    if let Some(later_types) = later_types {
557        let pre_edits: Vec<ErasedSegment> = vec![
558            SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
559                .finish(),
560            SegmentBuilder::symbol(tables.next_id(), "("),
561        ];
562
563        let in_edits: Vec<ErasedSegment> = vec![
564            SegmentBuilder::whitespace(tables.next_id(), " "),
565            SegmentBuilder::keyword(tables.next_id(), "as"),
566            SegmentBuilder::whitespace(tables.next_id(), " "),
567        ];
568
569        let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
570
571        for _type in later_types.base {
572            let mut xs = Vec::new();
573            xs.extend(pre_edits.clone());
574            xs.extend(edits);
575            xs.extend(in_edits.clone());
576            xs.push(_type);
577            xs.extend(post_edits.clone());
578            edits = xs;
579        }
580    }
581
582    vec![LintFix::replace(root, edits, None)]
583}