1use ahash::AHashMap;
2use itertools::chain;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::{ErasedSegment, SegmentBuilder, Tables};
6use sqruff_lib_core::utils::functional::segments::Segments;
7use strum_macros::{AsRefStr, EnumString};
8
9use crate::core::config::Value;
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::utils::functional::context::FunctionalContext;
14
15#[derive(Debug, Copy, Clone, AsRefStr, EnumString, PartialEq, Default)]
16#[strum(serialize_all = "snake_case")]
17enum TypeCastingStyle {
18 #[default]
19 Consistent,
20 Cast,
21 Convert,
22 Shorthand,
23 None,
24}
25
/// Zero-sized marker put into the rule context when, in `consistent` mode, a
/// CONVERT call with more than two arguments had to be skipped.
#[derive(Clone, Copy)]
struct PreviousSkipped;
28
29fn get_children(segments: Segments) -> Segments {
30 segments.children_where(|it: &ErasedSegment| {
31 !it.is_meta()
32 && !matches!(
33 it.get_type(),
34 SyntaxKind::StartBracket
35 | SyntaxKind::EndBracket
36 | SyntaxKind::Whitespace
37 | SyntaxKind::Newline
38 | SyntaxKind::CastingOperator
39 | SyntaxKind::Comma
40 | SyntaxKind::Keyword
41 )
42 })
43}
44
45fn shorthand_fix_list(
46 tables: &Tables,
47 root_segment: ErasedSegment,
48 shorthand_arg_1: ErasedSegment,
49 shorthand_arg_2: ErasedSegment,
50) -> Vec<LintFix> {
51 let mut edits = if shorthand_arg_1.get_raw_segments().len() > 1 {
52 vec![
53 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
54 shorthand_arg_1,
55 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
56 ]
57 } else {
58 vec![shorthand_arg_1]
59 };
60
61 edits.extend([
62 SegmentBuilder::token(tables.next_id(), "::", SyntaxKind::CastingOperator).finish(),
63 shorthand_arg_2,
64 ]);
65
66 vec![LintFix::replace(root_segment, edits, None)]
67}
68
69#[derive(Clone, Debug, Default)]
70pub struct RuleCV11 {
71 preferred_type_casting_style: TypeCastingStyle,
72}
73
74impl Rule for RuleCV11 {
75 fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
76 Ok(RuleCV11 {
77 preferred_type_casting_style: config["preferred_type_casting_style"]
78 .as_string()
79 .unwrap()
80 .parse()
81 .unwrap(),
82 }
83 .erased())
84 }
85
86 fn name(&self) -> &'static str {
87 "convention.casting_style"
88 }
89
90 fn description(&self) -> &'static str {
91 "Enforce consistent type casting style."
92 }
93
94 fn long_description(&self) -> &'static str {
95 r"
96**Anti-pattern**
97
98Using a mixture of `CONVERT`, `::`, and `CAST` when `preferred_type_casting_style` config is set to `consistent` (default).
99
100```sql
101SELECT
102 CONVERT(int, 1) AS bar,
103 100::int::text,
104 CAST(10 AS text) AS coo
105FROM foo;
106```
107
108**Best Practice**
109
110Use a consistent type casting style.
111
112```sql
113SELECT
114 CAST(1 AS int) AS bar,
115 CAST(CAST(100 AS int) AS text),
116 CAST(10 AS text) AS coo
117FROM foo;
118```
119"
120 }
121
122 fn groups(&self) -> &'static [RuleGroups] {
123 &[RuleGroups::All, RuleGroups::Convention]
124 }
125
126 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
127 if let Some(pos_marker) = &context.segment.get_position_marker()
132 && !pos_marker.is_literal()
133 {
134 return Vec::new();
135 }
136
137 let current_type_casting_style = if context.segment.is_type(SyntaxKind::Function) {
138 let Some(function_name) = context
139 .segment
140 .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
141 else {
142 return Vec::new();
143 };
144 if function_name.raw().eq_ignore_ascii_case("CAST") {
145 TypeCastingStyle::Cast
146 } else if function_name.raw().eq_ignore_ascii_case("CONVERT") {
147 TypeCastingStyle::Convert
148 } else {
149 TypeCastingStyle::None
150 }
151 } else if context.segment.is_type(SyntaxKind::CastExpression) {
152 TypeCastingStyle::Shorthand
153 } else {
154 TypeCastingStyle::None
155 };
156
157 let functional_context = FunctionalContext::new(context);
158 match self.preferred_type_casting_style {
159 TypeCastingStyle::Consistent => {
160 let Some(prior_type_casting_style) = context.try_get::<TypeCastingStyle>() else {
161 context.set(current_type_casting_style);
162 return Vec::new();
163 };
164 let previous_skipped = context.try_get::<PreviousSkipped>();
165
166 let mut fixes = Vec::new();
167 match prior_type_casting_style {
168 TypeCastingStyle::Cast => match current_type_casting_style {
169 TypeCastingStyle::Convert => {
170 let bracketed = functional_context
171 .segment()
172 .children_where(|it: &ErasedSegment| {
173 it.is_type(SyntaxKind::FunctionContents)
174 })
175 .children_where(|it: &ErasedSegment| {
176 it.is_type(SyntaxKind::Bracketed)
177 });
178 let convert_content = get_children(bracketed);
179 if convert_content.len() > 2 {
180 if previous_skipped.is_none() {
181 context.set(PreviousSkipped);
182 }
183 return Vec::new();
184 }
185
186 fixes = cast_fix_list(
187 context.tables,
188 context.segment.clone(),
189 &[convert_content[1].clone()],
190 convert_content[0].clone(),
191 None,
192 );
193 }
194 TypeCastingStyle::Shorthand => {
195 let expression_datatype_segment =
196 get_children(functional_context.segment());
197
198 fixes = cast_fix_list(
199 context.tables,
200 context.segment.clone(),
201 &[expression_datatype_segment[0].clone()],
202 expression_datatype_segment[1].clone(),
203 Some(Segments::from_vec(
204 expression_datatype_segment.base[2..].to_vec(),
205 None,
206 )),
207 )
208 }
209 _ => {}
210 },
211 TypeCastingStyle::Convert => match current_type_casting_style {
212 TypeCastingStyle::Cast => {
213 let bracketed = functional_context
214 .segment()
215 .children_where(|it: &ErasedSegment| {
216 it.is_type(SyntaxKind::FunctionContents)
217 })
218 .children_where(|it: &ErasedSegment| {
219 it.is_type(SyntaxKind::Bracketed)
220 });
221 let cast_content = get_children(bracketed);
222
223 if cast_content.len() > 2 {
224 return Vec::new();
225 }
226
227 fixes = convert_fix_list(
228 context.tables,
229 context.segment.clone(),
230 cast_content[1].clone(),
231 cast_content[0].clone(),
232 None,
233 );
234 }
235 TypeCastingStyle::Shorthand => {
236 let expression_datatype_segment =
237 get_children(functional_context.segment());
238
239 fixes = convert_fix_list(
240 context.tables,
241 context.segment.clone(),
242 expression_datatype_segment[1].clone(),
243 expression_datatype_segment[0].clone(),
244 Some(Segments::from_vec(
245 expression_datatype_segment.base[2..].to_vec(),
246 None,
247 )),
248 );
249 }
250 _ => (),
251 },
252 TypeCastingStyle::Shorthand => {
253 if current_type_casting_style == TypeCastingStyle::Cast {
254 let bracketed = functional_context
256 .segment()
257 .children_where(|it: &ErasedSegment| {
258 it.is_type(SyntaxKind::FunctionContents)
259 })
260 .children_where(|it: &ErasedSegment| {
261 it.is_type(SyntaxKind::Bracketed)
262 });
263 let cast_content = get_children(bracketed);
264 if cast_content.len() > 2 {
265 return Vec::new();
266 }
267
268 fixes = shorthand_fix_list(
269 context.tables,
270 context.segment.clone(),
271 cast_content[0].clone(),
272 cast_content[1].clone(),
273 );
274 } else if current_type_casting_style == TypeCastingStyle::Convert {
275 let bracketed = functional_context
276 .segment()
277 .children_where(|it: &ErasedSegment| {
278 it.is_type(SyntaxKind::FunctionContents)
279 })
280 .children_where(|it: &ErasedSegment| {
281 it.is_type(SyntaxKind::Bracketed)
282 });
283 let convert_content = get_children(bracketed);
284 if convert_content.len() > 2 {
285 return Vec::new();
286 }
287
288 fixes = shorthand_fix_list(
289 context.tables,
290 context.segment.clone(),
291 convert_content[1].clone(),
292 convert_content[0].clone(),
293 );
294 }
295 }
296 _ => {}
297 }
298
299 if prior_type_casting_style != current_type_casting_style {
300 return vec![LintResult::new(
301 context.segment.clone().into(),
302 fixes,
303 "Inconsistent type casting styles found.".to_owned().into(),
304 None,
305 )];
306 }
307 }
308 _ if current_type_casting_style != self.preferred_type_casting_style => {
309 let mut convert_content = None;
310 let mut cast_content = None;
311 let mut fixes = Vec::new();
312
313 match self.preferred_type_casting_style {
314 TypeCastingStyle::Cast => match current_type_casting_style {
315 TypeCastingStyle::Convert => {
316 let bracketed = functional_context
317 .segment()
318 .children_where(|it: &ErasedSegment| {
319 it.is_type(SyntaxKind::FunctionContents)
320 })
321 .children_where(|it: &ErasedSegment| {
322 it.is_type(SyntaxKind::Bracketed)
323 });
324 let segments = get_children(bracketed);
325 fixes = cast_fix_list(
326 context.tables,
327 context.segment.clone(),
328 &[segments[1].clone()],
329 segments[0].clone(),
330 None,
331 );
332 convert_content = Some(segments);
333 }
334 TypeCastingStyle::Shorthand => {
335 let expression_datatype_segment =
336 get_children(functional_context.segment());
337 let data_type_idx = expression_datatype_segment
338 .iter()
339 .position(|seg| seg.is_type(SyntaxKind::DataType))
340 .unwrap();
341
342 fixes = cast_fix_list(
343 context.tables,
344 context.segment.clone(),
345 &expression_datatype_segment[..data_type_idx],
346 expression_datatype_segment[data_type_idx].clone(),
347 Some(Segments::from_vec(
348 expression_datatype_segment.base[data_type_idx + 1..].to_vec(),
349 None,
350 )),
351 );
352 }
353 _ => {}
354 },
355 TypeCastingStyle::Convert => match current_type_casting_style {
356 TypeCastingStyle::Cast => {
357 let bracketed = functional_context
358 .segment()
359 .children_where(|it: &ErasedSegment| {
360 it.is_type(SyntaxKind::FunctionContents)
361 })
362 .children_where(|it: &ErasedSegment| {
363 it.is_type(SyntaxKind::Bracketed)
364 });
365 let cast_content = get_children(bracketed);
366
367 fixes = convert_fix_list(
368 context.tables,
369 context.segment.clone(),
370 cast_content[1].clone(),
371 cast_content[0].clone(),
372 None,
373 );
374 }
375 TypeCastingStyle::Shorthand => {
376 let cast_content = get_children(functional_context.segment());
377
378 fixes = convert_fix_list(
379 context.tables,
380 context.segment.clone(),
381 cast_content[1].clone(),
382 cast_content[0].clone(),
383 Some(Segments::from_vec(cast_content.base[2..].to_vec(), None)),
384 )
385 }
386 _ => {}
387 },
388 TypeCastingStyle::Shorthand => match current_type_casting_style {
389 TypeCastingStyle::Cast => {
390 let bracketed = functional_context
391 .segment()
392 .children_where(|it: &ErasedSegment| {
393 it.is_type(SyntaxKind::FunctionContents)
394 })
395 .children_where(|it: &ErasedSegment| {
396 it.is_type(SyntaxKind::Bracketed)
397 });
398 let segments = get_children(bracketed);
399
400 fixes = shorthand_fix_list(
401 context.tables,
402 context.segment.clone(),
403 segments[0].clone(),
404 segments[1].clone(),
405 );
406 cast_content = Some(segments);
407 }
408 TypeCastingStyle::Convert => {
409 let bracketed = functional_context
410 .segment()
411 .children_where(|it: &ErasedSegment| {
412 it.is_type(SyntaxKind::FunctionContents)
413 })
414 .children_where(|it: &ErasedSegment| {
415 it.is_type(SyntaxKind::Bracketed)
416 });
417 let segments = get_children(bracketed);
418
419 fixes = shorthand_fix_list(
420 context.tables,
421 context.segment.clone(),
422 segments[1].clone(),
423 segments[0].clone(),
424 );
425
426 convert_content = Some(segments);
427 }
428 _ => {}
429 },
430 _ => {}
431 }
432
433 if convert_content
434 .filter(|convert_content| convert_content.len() > 2)
435 .is_some()
436 {
437 fixes.clear();
438 }
439
440 if cast_content
441 .filter(|cast_content| cast_content.len() > 2)
442 .is_some()
443 {
444 fixes.clear();
445 }
446
447 return vec![LintResult::new(
448 context.segment.clone().into(),
449 fixes,
450 "Used type casting style is different from the preferred type casting style."
451 .to_owned()
452 .into(),
453 None,
454 )];
455 }
456
457 _ => {}
458 }
459
460 Vec::new()
461 }
462
463 fn is_fix_compatible(&self) -> bool {
464 true
465 }
466
467 fn crawl_behaviour(&self) -> Crawler {
468 SegmentSeekerCrawler::new(
469 const { SyntaxSet::new(&[SyntaxKind::Function, SyntaxKind::CastExpression]) },
470 )
471 .into()
472 }
473}
474
475fn convert_fix_list(
476 tables: &Tables,
477 root: ErasedSegment,
478 convert_arg_1: ErasedSegment,
479 convert_arg_2: ErasedSegment,
480 later_types: Option<Segments>,
481) -> Vec<LintFix> {
482 use sqruff_lib_core::parser::segments::ErasedSegment;
483
484 let mut edits: Vec<ErasedSegment> = vec![
485 SegmentBuilder::token(
486 tables.next_id(),
487 "convert",
488 SyntaxKind::FunctionNameIdentifier,
489 )
490 .finish(),
491 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
492 convert_arg_1,
493 SegmentBuilder::token(tables.next_id(), ",", SyntaxKind::Comma).finish(),
494 SegmentBuilder::whitespace(tables.next_id(), " "),
495 convert_arg_2,
496 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
497 ];
498
499 if let Some(later_types) = later_types {
500 let pre_edits: Vec<ErasedSegment> = vec![
501 SegmentBuilder::token(
502 tables.next_id(),
503 "convert",
504 SyntaxKind::FunctionNameIdentifier,
505 )
506 .finish(),
507 SegmentBuilder::symbol(tables.next_id(), "("),
508 ];
509
510 let in_edits: Vec<ErasedSegment> = vec![
511 SegmentBuilder::symbol(tables.next_id(), ","),
512 SegmentBuilder::whitespace(tables.next_id(), " "),
513 ];
514
515 let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
516
517 for _type in later_types.base {
518 edits = chain(
519 chain(pre_edits.clone(), vec![_type]),
520 chain(in_edits.clone(), chain(edits, post_edits.clone())),
521 )
522 .collect();
523 }
524 }
525
526 vec![LintFix::replace(root, edits, None)]
527}
528
529fn cast_fix_list(
530 tables: &Tables,
531 root: ErasedSegment,
532 cast_arg_1: &[ErasedSegment],
533 cast_arg_2: ErasedSegment,
534 later_types: Option<Segments>,
535) -> Vec<LintFix> {
536 let mut edits = vec![
537 SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
538 .finish(),
539 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
540 ];
541 edits.extend_from_slice(cast_arg_1);
542 edits.extend([
543 SegmentBuilder::whitespace(tables.next_id(), " "),
544 SegmentBuilder::keyword(tables.next_id(), "as"),
545 SegmentBuilder::whitespace(tables.next_id(), " "),
546 cast_arg_2,
547 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
548 ]);
549
550 if let Some(later_types) = later_types {
551 let pre_edits: Vec<ErasedSegment> = vec![
552 SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
553 .finish(),
554 SegmentBuilder::symbol(tables.next_id(), "("),
555 ];
556
557 let in_edits: Vec<ErasedSegment> = vec![
558 SegmentBuilder::whitespace(tables.next_id(), " "),
559 SegmentBuilder::keyword(tables.next_id(), "as"),
560 SegmentBuilder::whitespace(tables.next_id(), " "),
561 ];
562
563 let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
564
565 for _type in later_types.base {
566 let mut xs = Vec::new();
567 xs.extend(pre_edits.clone());
568 xs.extend(edits);
569 xs.extend(in_edits.clone());
570 xs.push(_type);
571 xs.extend(post_edits.clone());
572 edits = xs;
573 }
574 }
575
576 vec![LintFix::replace(root, edits, None)]
577}