1use ahash::AHashMap;
2use itertools::chain;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder, Tables};
6use sqruff_lib_core::utils::functional::segments::Segments;
7use strum_macros::{AsRefStr, EnumString};
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14
/// Type casting styles recognised by CV11.
///
/// The same enum holds the value of the `preferred_type_casting_style` option, parsed with
/// strum using `snake_case` names (e.g. `"consistent"`, `"shorthand"`).
#[derive(Debug, Copy, Clone, AsRefStr, EnumString, PartialEq, Default)]
#[strum(serialize_all = "snake_case")]
enum TypeCastingStyle {
    /// Adopt the style of the first cast the rule encounters (the default).
    #[default]
    Consistent,
    /// `CAST(expr AS type)`.
    Cast,
    /// `CONVERT(type, expr)`.
    Convert,
    /// `expr::type`.
    Shorthand,
    /// The visited segment is not a cast at all.
    None,
}
25
/// Marker stored in the rule context in `consistent` mode.
///
/// It is set when `CAST` is the established style and a `CONVERT` call is skipped because
/// it cannot be rewritten as `CAST` (e.g. it has more than two arguments). In the code
/// visible here it is only set, after checking that it is not already present.
#[derive(Copy, Clone)]
struct PreviousSkipped;
28
29fn get_children(segments: Segments) -> Segments {
30 segments.children_where(|it: &ErasedSegment| {
31 !it.is_meta()
32 && !matches!(
33 it.get_type(),
34 SyntaxKind::StartBracket
35 | SyntaxKind::EndBracket
36 | SyntaxKind::Whitespace
37 | SyntaxKind::Newline
38 | SyntaxKind::CastingOperator
39 | SyntaxKind::Comma
40 | SyntaxKind::Keyword
41 )
42 })
43}
44
45fn shorthand_fix_list(
46 tables: &Tables,
47 root_segment: ErasedSegment,
48 shorthand_arg_1: ErasedSegment,
49 shorthand_arg_2: ErasedSegment,
50) -> Vec<LintFix> {
51 let mut edits = if shorthand_arg_1.get_raw_segments().len() > 1 {
52 vec![
53 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
54 shorthand_arg_1,
55 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
56 ]
57 } else {
58 vec![shorthand_arg_1]
59 };
60
61 edits.extend([
62 SegmentBuilder::token(tables.next_id(), "::", SyntaxKind::CastingOperator).finish(),
63 shorthand_arg_2,
64 ]);
65
66 vec![LintFix::replace(root_segment, edits, None)]
67}
68
/// CV11 (`convention.casting_style`): enforces a single type casting style.
#[derive(Clone, Debug, Default)]
pub struct RuleCV11 {
    /// Style to enforce, read from the `preferred_type_casting_style` config value.
    /// `Consistent` (the default) adopts the style of the first cast encountered.
    preferred_type_casting_style: TypeCastingStyle,
}
73
74impl Rule for RuleCV11 {
75 fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
76 Ok(RuleCV11 {
77 preferred_type_casting_style: config["preferred_type_casting_style"]
78 .as_string()
79 .unwrap()
80 .parse()
81 .unwrap(),
82 }
83 .erased())
84 }
85
86 fn name(&self) -> &'static str {
87 "convention.casting_style"
88 }
89
90 fn description(&self) -> &'static str {
91 "Enforce consistent type casting style."
92 }
93
94 fn long_description(&self) -> &'static str {
95 r"
96**Anti-pattern**
97
98Using a mixture of `CONVERT`, `::`, and `CAST` when `preferred_type_casting_style` config is set to `consistent` (default).
99
100```sql
101SELECT
102 CONVERT(int, 1) AS bar,
103 100::int::text,
104 CAST(10 AS text) AS coo
105FROM foo;
106```
107
108**Best Practice**
109
110Use a consistent type casting style.
111
112```sql
113SELECT
114 CAST(1 AS int) AS bar,
115 CAST(CAST(100 AS int) AS text),
116 CAST(10 AS text) AS coo
117FROM foo;
118```
119"
120 }
121
122 fn groups(&self) -> &'static [RuleGroups] {
123 &[RuleGroups::All, RuleGroups::Convention]
124 }
125
126 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
127 if let Some(pos_marker) = &context.segment.get_position_marker()
132 && !pos_marker.is_literal()
133 {
134 return Vec::new();
135 }
136
137 let current_type_casting_style = if context.segment.is_type(SyntaxKind::Function) {
138 let Some(function_name) = context
139 .segment
140 .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
141 else {
142 return Vec::new();
143 };
144 if function_name.raw().eq_ignore_ascii_case("CAST") {
145 TypeCastingStyle::Cast
146 } else if function_name.raw().eq_ignore_ascii_case("CONVERT") {
147 TypeCastingStyle::Convert
148 } else {
149 TypeCastingStyle::None
150 }
151 } else if context.segment.is_type(SyntaxKind::CastExpression) {
152 TypeCastingStyle::Shorthand
153 } else {
154 TypeCastingStyle::None
155 };
156
157 let functional_context = FunctionalContext::new(context);
158 match self.preferred_type_casting_style {
159 TypeCastingStyle::Consistent => {
160 if current_type_casting_style == TypeCastingStyle::None {
163 return Vec::new();
164 }
165
166 let Some(prior_type_casting_style) = context.try_get::<TypeCastingStyle>() else {
167 context.set(current_type_casting_style);
168 return Vec::new();
169 };
170 let previous_skipped = context.try_get::<PreviousSkipped>();
171
172 let mut fixes = Vec::new();
173 match prior_type_casting_style {
174 TypeCastingStyle::Cast => match current_type_casting_style {
175 TypeCastingStyle::Convert => {
176 let bracketed = functional_context
177 .segment()
178 .children_where(|it: &ErasedSegment| {
179 it.is_type(SyntaxKind::FunctionContents)
180 })
181 .children_where(|it: &ErasedSegment| {
182 it.is_type(SyntaxKind::Bracketed)
183 });
184 let convert_content = get_children(bracketed);
185 if convert_content.len() > 2 {
186 if previous_skipped.is_none() {
187 context.set(PreviousSkipped);
188 }
189 return Vec::new();
190 }
191
192 fixes = cast_fix_list(
193 context.tables,
194 context.segment.clone(),
195 &[convert_content[1].clone()],
196 convert_content[0].clone(),
197 None,
198 );
199 }
200 TypeCastingStyle::Shorthand => {
201 let expression_datatype_segment =
202 get_children(functional_context.segment());
203
204 fixes = cast_fix_list(
205 context.tables,
206 context.segment.clone(),
207 &[expression_datatype_segment[0].clone()],
208 expression_datatype_segment[1].clone(),
209 Some(Segments::from_vec(
210 expression_datatype_segment.base[2..].to_vec(),
211 None,
212 )),
213 )
214 }
215 _ => {}
216 },
217 TypeCastingStyle::Convert => match current_type_casting_style {
218 TypeCastingStyle::Cast => {
219 let bracketed = functional_context
220 .segment()
221 .children_where(|it: &ErasedSegment| {
222 it.is_type(SyntaxKind::FunctionContents)
223 })
224 .children_where(|it: &ErasedSegment| {
225 it.is_type(SyntaxKind::Bracketed)
226 });
227 let cast_content = get_children(bracketed);
228
229 if cast_content.len() > 2 {
230 return Vec::new();
231 }
232
233 fixes = convert_fix_list(
234 context.tables,
235 context.segment.clone(),
236 cast_content[1].clone(),
237 cast_content[0].clone(),
238 None,
239 );
240 }
241 TypeCastingStyle::Shorthand => {
242 let expression_datatype_segment =
243 get_children(functional_context.segment());
244
245 fixes = convert_fix_list(
246 context.tables,
247 context.segment.clone(),
248 expression_datatype_segment[1].clone(),
249 expression_datatype_segment[0].clone(),
250 Some(Segments::from_vec(
251 expression_datatype_segment.base[2..].to_vec(),
252 None,
253 )),
254 );
255 }
256 _ => (),
257 },
258 TypeCastingStyle::Shorthand => {
259 if current_type_casting_style == TypeCastingStyle::Cast {
260 let bracketed = functional_context
262 .segment()
263 .children_where(|it: &ErasedSegment| {
264 it.is_type(SyntaxKind::FunctionContents)
265 })
266 .children_where(|it: &ErasedSegment| {
267 it.is_type(SyntaxKind::Bracketed)
268 });
269 let cast_content = get_children(bracketed);
270 if cast_content.len() > 2 {
271 return Vec::new();
272 }
273
274 fixes = shorthand_fix_list(
275 context.tables,
276 context.segment.clone(),
277 cast_content[0].clone(),
278 cast_content[1].clone(),
279 );
280 } else if current_type_casting_style == TypeCastingStyle::Convert {
281 let bracketed = functional_context
282 .segment()
283 .children_where(|it: &ErasedSegment| {
284 it.is_type(SyntaxKind::FunctionContents)
285 })
286 .children_where(|it: &ErasedSegment| {
287 it.is_type(SyntaxKind::Bracketed)
288 });
289 let convert_content = get_children(bracketed);
290 if convert_content.len() > 2 {
291 return Vec::new();
292 }
293
294 fixes = shorthand_fix_list(
295 context.tables,
296 context.segment.clone(),
297 convert_content[1].clone(),
298 convert_content[0].clone(),
299 );
300 }
301 }
302 _ => {}
303 }
304
305 if prior_type_casting_style != current_type_casting_style {
306 return vec![LintResult::new(
307 context.segment.clone().into(),
308 fixes,
309 "Inconsistent type casting styles found.".to_owned().into(),
310 None,
311 )];
312 }
313 }
314 _ if current_type_casting_style != self.preferred_type_casting_style => {
315 let mut convert_content = None;
316 let mut cast_content = None;
317 let mut fixes = Vec::new();
318
319 match self.preferred_type_casting_style {
320 TypeCastingStyle::Cast => match current_type_casting_style {
321 TypeCastingStyle::Convert => {
322 let bracketed = functional_context
323 .segment()
324 .children_where(|it: &ErasedSegment| {
325 it.is_type(SyntaxKind::FunctionContents)
326 })
327 .children_where(|it: &ErasedSegment| {
328 it.is_type(SyntaxKind::Bracketed)
329 });
330 let segments = get_children(bracketed);
331 fixes = cast_fix_list(
332 context.tables,
333 context.segment.clone(),
334 &[segments[1].clone()],
335 segments[0].clone(),
336 None,
337 );
338 convert_content = Some(segments);
339 }
340 TypeCastingStyle::Shorthand => {
341 let expression_datatype_segment =
342 get_children(functional_context.segment());
343 let data_type_idx = expression_datatype_segment
344 .iter()
345 .position(|seg| seg.is_type(SyntaxKind::DataType))
346 .unwrap();
347
348 fixes = cast_fix_list(
349 context.tables,
350 context.segment.clone(),
351 &expression_datatype_segment[..data_type_idx],
352 expression_datatype_segment[data_type_idx].clone(),
353 Some(Segments::from_vec(
354 expression_datatype_segment.base[data_type_idx + 1..].to_vec(),
355 None,
356 )),
357 );
358 }
359 _ => {}
360 },
361 TypeCastingStyle::Convert => match current_type_casting_style {
362 TypeCastingStyle::Cast => {
363 let bracketed = functional_context
364 .segment()
365 .children_where(|it: &ErasedSegment| {
366 it.is_type(SyntaxKind::FunctionContents)
367 })
368 .children_where(|it: &ErasedSegment| {
369 it.is_type(SyntaxKind::Bracketed)
370 });
371 let cast_content = get_children(bracketed);
372
373 fixes = convert_fix_list(
374 context.tables,
375 context.segment.clone(),
376 cast_content[1].clone(),
377 cast_content[0].clone(),
378 None,
379 );
380 }
381 TypeCastingStyle::Shorthand => {
382 let cast_content = get_children(functional_context.segment());
383
384 fixes = convert_fix_list(
385 context.tables,
386 context.segment.clone(),
387 cast_content[1].clone(),
388 cast_content[0].clone(),
389 Some(Segments::from_vec(cast_content.base[2..].to_vec(), None)),
390 )
391 }
392 _ => {}
393 },
394 TypeCastingStyle::Shorthand => match current_type_casting_style {
395 TypeCastingStyle::Cast => {
396 let bracketed = functional_context
397 .segment()
398 .children_where(|it: &ErasedSegment| {
399 it.is_type(SyntaxKind::FunctionContents)
400 })
401 .children_where(|it: &ErasedSegment| {
402 it.is_type(SyntaxKind::Bracketed)
403 });
404 let segments = get_children(bracketed);
405
406 fixes = shorthand_fix_list(
407 context.tables,
408 context.segment.clone(),
409 segments[0].clone(),
410 segments[1].clone(),
411 );
412 cast_content = Some(segments);
413 }
414 TypeCastingStyle::Convert => {
415 let bracketed = functional_context
416 .segment()
417 .children_where(|it: &ErasedSegment| {
418 it.is_type(SyntaxKind::FunctionContents)
419 })
420 .children_where(|it: &ErasedSegment| {
421 it.is_type(SyntaxKind::Bracketed)
422 });
423 let segments = get_children(bracketed);
424
425 fixes = shorthand_fix_list(
426 context.tables,
427 context.segment.clone(),
428 segments[1].clone(),
429 segments[0].clone(),
430 );
431
432 convert_content = Some(segments);
433 }
434 _ => {}
435 },
436 _ => {}
437 }
438
439 if convert_content
440 .filter(|convert_content| convert_content.len() > 2)
441 .is_some()
442 {
443 fixes.clear();
444 }
445
446 if cast_content
447 .filter(|cast_content| cast_content.len() > 2)
448 .is_some()
449 {
450 fixes.clear();
451 }
452
453 return vec![LintResult::new(
454 context.segment.clone().into(),
455 fixes,
456 "Used type casting style is different from the preferred type casting style."
457 .to_owned()
458 .into(),
459 None,
460 )];
461 }
462
463 _ => {}
464 }
465
466 Vec::new()
467 }
468
469 fn is_fix_compatible(&self) -> bool {
470 true
471 }
472
473 fn crawl_behaviour(&self) -> Crawler {
474 SegmentSeekerCrawler::new(
475 const { SyntaxSet::new(&[SyntaxKind::Function, SyntaxKind::CastExpression]) },
476 )
477 .into()
478 }
479}
480
481fn convert_fix_list(
482 tables: &Tables,
483 root: ErasedSegment,
484 convert_arg_1: ErasedSegment,
485 convert_arg_2: ErasedSegment,
486 later_types: Option<Segments>,
487) -> Vec<LintFix> {
488 use sqruff_lib_core::parser::segments::ErasedSegment;
489
490 let mut edits: Vec<ErasedSegment> = vec![
491 SegmentBuilder::token(
492 tables.next_id(),
493 "convert",
494 SyntaxKind::FunctionNameIdentifier,
495 )
496 .finish(),
497 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
498 convert_arg_1,
499 SegmentBuilder::token(tables.next_id(), ",", SyntaxKind::Comma).finish(),
500 SegmentBuilder::whitespace(tables.next_id(), " "),
501 convert_arg_2,
502 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
503 ];
504
505 if let Some(later_types) = later_types {
506 let pre_edits: Vec<ErasedSegment> = vec![
507 SegmentBuilder::token(
508 tables.next_id(),
509 "convert",
510 SyntaxKind::FunctionNameIdentifier,
511 )
512 .finish(),
513 SegmentBuilder::symbol(tables.next_id(), "("),
514 ];
515
516 let in_edits: Vec<ErasedSegment> = vec![
517 SegmentBuilder::symbol(tables.next_id(), ","),
518 SegmentBuilder::whitespace(tables.next_id(), " "),
519 ];
520
521 let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
522
523 for _type in later_types.base {
524 edits = chain(
525 chain(pre_edits.clone(), vec![_type]),
526 chain(in_edits.clone(), chain(edits, post_edits.clone())),
527 )
528 .collect();
529 }
530 }
531
532 vec![LintFix::replace(root, edits, None)]
533}
534
535fn cast_fix_list(
536 tables: &Tables,
537 root: ErasedSegment,
538 cast_arg_1: &[ErasedSegment],
539 cast_arg_2: ErasedSegment,
540 later_types: Option<Segments>,
541) -> Vec<LintFix> {
542 let mut edits = vec![
543 SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
544 .finish(),
545 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
546 ];
547 edits.extend_from_slice(cast_arg_1);
548 edits.extend([
549 SegmentBuilder::whitespace(tables.next_id(), " "),
550 SegmentBuilder::keyword(tables.next_id(), "as"),
551 SegmentBuilder::whitespace(tables.next_id(), " "),
552 cast_arg_2,
553 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
554 ]);
555
556 if let Some(later_types) = later_types {
557 let pre_edits: Vec<ErasedSegment> = vec![
558 SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
559 .finish(),
560 SegmentBuilder::symbol(tables.next_id(), "("),
561 ];
562
563 let in_edits: Vec<ErasedSegment> = vec![
564 SegmentBuilder::whitespace(tables.next_id(), " "),
565 SegmentBuilder::keyword(tables.next_id(), "as"),
566 SegmentBuilder::whitespace(tables.next_id(), " "),
567 ];
568
569 let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
570
571 for _type in later_types.base {
572 let mut xs = Vec::new();
573 xs.extend(pre_edits.clone());
574 xs.extend(edits);
575 xs.extend(in_edits.clone());
576 xs.push(_type);
577 xs.extend(post_edits.clone());
578 edits = xs;
579 }
580 }
581
582 vec![LintFix::replace(root, edits, None)]
583}