1use std::cmp::PartialEq;
2use std::str::FromStr;
3
4use sqruff_lib_core::dialects::syntax::SyntaxKind;
5use sqruff_lib_core::helpers::capitalize;
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::{ErasedSegment, Tables};
8use strum_macros::{AsRefStr, EnumString};
9
10use super::elements::{ReflowElement, ReflowSequenceType};
11use crate::core::rules::LintResult;
12use crate::utils::reflow::depth_map::StackPositionType;
13use crate::utils::reflow::elements::ReflowPoint;
14use crate::utils::reflow::helpers::{deduce_line_indent, fixes_from_results};
15
/// A stretch of the reflow element buffer that carries a line-position
/// configuration and may therefore need its surrounding line breaks moved.
///
/// Instances are produced by [`identify_rebreak_spans`]; `start_idx` and
/// `end_idx` are indices into the element buffer that was scanned.
#[derive(Debug)]
pub struct RebreakSpan {
    /// The segment the line-position configuration applies to.
    pub(crate) target: ErasedSegment,
    /// Buffer index of the first element covered by the span.
    pub(crate) start_idx: usize,
    /// Buffer index of the last element covered by the span.
    pub(crate) end_idx: usize,
    /// The configured position of the target relative to a line break.
    pub(crate) line_position: LinePosition,
    /// Whether the configuration carried the `strict` modifier.
    pub(crate) strict: bool,
}
24
25#[derive(Debug)]
26pub struct RebreakIndices {
27 _dir: i32,
28 adj_pt_idx: isize,
29 newline_pt_idx: isize,
30 pre_code_pt_idx: isize,
31}
32
33impl RebreakIndices {
34 fn from_elements(elements: &ReflowSequenceType, start_idx: usize, dir: i32) -> Option<Self> {
35 assert!(dir == 1 || dir == -1);
36 let limit = if dir == -1 { 0 } else { elements.len() };
37 let adj_point_idx = start_idx as isize + dir as isize;
38
39 if adj_point_idx < 0 || adj_point_idx >= elements.len() as isize {
40 return None;
41 }
42
43 let mut newline_point_idx = adj_point_idx;
44 while (dir == 1 && newline_point_idx < limit as isize)
45 || (dir == -1 && newline_point_idx >= 0)
46 {
47 let adjacent_idx = newline_point_idx + dir as isize;
52 if adjacent_idx < 0 || adjacent_idx >= elements.len() as isize {
53 break;
54 }
55 if elements[newline_point_idx as usize]
56 .class_types()
57 .contains(SyntaxKind::Newline)
58 || elements[adjacent_idx as usize]
59 .segments()
60 .iter()
61 .any(|seg| seg.is_code())
62 {
63 break;
64 }
65 newline_point_idx += 2 * dir as isize;
66 }
67
68 if newline_point_idx < 0 || newline_point_idx >= elements.len() as isize {
71 newline_point_idx = adj_point_idx;
72 }
73
74 let mut pre_code_point_idx = newline_point_idx;
75 while (dir == 1 && pre_code_point_idx < limit as isize)
76 || (dir == -1 && pre_code_point_idx >= 0)
77 {
78 let adjacent_idx = pre_code_point_idx + dir as isize;
80 if adjacent_idx < 0 || adjacent_idx >= elements.len() as isize {
81 break;
82 }
83 if elements[adjacent_idx as usize]
84 .segments()
85 .iter()
86 .any(|seg| seg.is_code())
87 {
88 break;
89 }
90 pre_code_point_idx += 2 * dir as isize;
91 }
92
93 if pre_code_point_idx < 0 || pre_code_point_idx >= elements.len() as isize {
95 pre_code_point_idx = newline_point_idx;
96 }
97
98 RebreakIndices {
99 _dir: dir,
100 adj_pt_idx: adj_point_idx,
101 newline_pt_idx: newline_point_idx,
102 pre_code_pt_idx: pre_code_point_idx,
103 }
104 .into()
105 }
106}
107
/// A [`RebreakSpan`] resolved against an element buffer: the target plus the
/// indices found by walking backwards from the span's start and forwards from
/// its end.
#[derive(Debug)]
pub struct RebreakLocation {
    /// The segment the line-position configuration applies to.
    target: ErasedSegment,
    /// Indices found walking backwards (direction `-1`) from the span start.
    prev: RebreakIndices,
    /// Indices found walking forwards (direction `1`) from the span end.
    next: RebreakIndices,
    /// The configured position of the target relative to a line break.
    line_position: LinePosition,
    /// Whether the configuration carried the `strict` modifier.
    strict: bool,
}
116
/// Configured placement of a segment relative to a line break.
///
/// The string form (used by `FromStr` and `AsRef<str>`) is the lowercase
/// variant name.
#[derive(Debug, PartialEq, Clone, Copy, AsRefStr, EnumString)]
#[strum(serialize_all = "lowercase")]
pub enum LinePosition {
    /// The target should start a line.
    Leading,
    /// The target should end a line.
    Trailing,
    /// The target should have a line break on both sides.
    Alone,
    /// Modifier marking a configuration as strict; `rebreak_sequence` does not
    /// handle it as a placement of its own.
    Strict,
}
125
126impl RebreakLocation {
127 pub fn from_span(span: RebreakSpan, elements: &ReflowSequenceType) -> Option<Self> {
129 Self {
130 target: span.target,
131 prev: RebreakIndices::from_elements(elements, span.start_idx, -1)?,
132 next: RebreakIndices::from_elements(elements, span.end_idx, 1)?,
133 line_position: span.line_position,
134 strict: span.strict,
135 }
136 .into()
137 }
138
139 fn has_inappropriate_newlines(&self, elements: &ReflowSequenceType, strict: bool) -> bool {
140 let n_prev_newlines = elements[self.prev.newline_pt_idx as usize].num_newlines();
141 let n_next_newlines = elements[self.next.newline_pt_idx as usize].num_newlines();
142
143 let newlines_on_neither_side = n_prev_newlines + n_next_newlines == 0;
144 let newlines_on_both_sides = n_prev_newlines > 0 && n_next_newlines > 0;
145
146 (newlines_on_neither_side && !strict) || newlines_on_both_sides
147 }
148
149 fn pretty_target_name(&self) -> String {
150 format!("{} {}", self.target.get_type().as_str(), self.target.raw())
151 }
152}
153
154pub fn identify_rebreak_spans(
155 element_buffer: &ReflowSequenceType,
156 root_segment: &ErasedSegment,
157) -> Vec<RebreakSpan> {
158 let mut spans = Vec::new();
159
160 for (idx, elem) in element_buffer
161 .iter()
162 .enumerate()
163 .take(element_buffer.len() - 2)
164 .skip(2)
165 {
166 let ReflowElement::Block(block) = elem else {
167 continue;
168 };
169
170 if let Some(original_line_position) = block.line_position() {
171 let line_position = original_line_position.first().unwrap();
172 spans.push(RebreakSpan {
173 target: elem.segments().first().cloned().unwrap(),
174 start_idx: idx,
175 end_idx: idx,
176 line_position: *line_position,
177 strict: original_line_position.last() == Some(&LinePosition::Strict),
178 });
179 }
180
181 for key in block.line_position_configs().keys() {
182 let mut final_idx = None;
183 if block.depth_info().stack_positions[key].idx != 0 {
184 continue;
185 }
186
187 for (end_idx, end_elem) in element_buffer.iter().enumerate().skip(idx) {
188 let ReflowElement::Block(end_block) = end_elem else {
189 continue;
190 };
191
192 if !end_block.depth_info().stack_positions.contains_key(key) {
193 if final_idx.is_none() {
195 final_idx = (end_idx - 2).into();
196 }
197 break;
198 } else if matches!(
199 end_block.depth_info().stack_positions[key].type_,
200 Some(StackPositionType::End) | Some(StackPositionType::Solo)
201 ) {
202 final_idx = end_idx.into();
206 }
207 }
208
209 if let Some(final_idx) = final_idx {
210 let target_depth = block
211 .depth_info()
212 .stack_hashes
213 .iter()
214 .position(|it| it == key)
215 .unwrap();
216 let target = root_segment.path_to(&element_buffer[idx].segments()[0])[target_depth]
217 .segment
218 .clone();
219
220 let line_position_configs = block.line_position_configs()[key]
221 .split(':')
222 .next()
223 .unwrap();
224 let line_position = LinePosition::from_str(line_position_configs).unwrap();
225
226 spans.push(RebreakSpan {
227 target,
228 start_idx: idx,
229 end_idx: final_idx,
230 line_position,
231 strict: block.line_position_configs()[key].ends_with("strict"),
232 });
233 }
234 }
235 }
236
237 spans
238}
239
240pub fn rebreak_sequence(
241 tables: &Tables,
242 elements: ReflowSequenceType,
243 root_segment: &ErasedSegment,
244) -> (ReflowSequenceType, Vec<LintResult>) {
245 let mut lint_results = Vec::new();
246 let mut fixes = Vec::new();
247 let mut elem_buff = elements.clone();
248
249 let spans = identify_rebreak_spans(&elem_buff, root_segment);
258
259 let mut locations = Vec::new();
260 for span in spans {
261 if let Some(loc) = RebreakLocation::from_span(span, &elements) {
262 locations.push(loc);
263 }
264 }
265
266 for loc in locations {
268 if loc.has_inappropriate_newlines(&elements, loc.strict) {
269 continue;
270 }
271
272 let prev_point = elem_buff[loc.prev.adj_pt_idx as usize]
278 .as_point()
279 .unwrap()
280 .clone();
281 let next_point = elem_buff[loc.next.adj_pt_idx as usize]
282 .as_point()
283 .unwrap()
284 .clone();
285
286 let new_results = if loc.line_position == LinePosition::Leading {
288 if elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() != 0 {
289 continue;
291 }
292
293 let pretty_name = loc.pretty_target_name();
295 let _desc = if loc.strict {
296 format!(
297 "{} should always start a new line.",
298 capitalize(&pretty_name)
299 )
300 } else {
301 format!("Found trailing {pretty_name}. Expected only leading near line breaks.")
302 };
303
304 if loc.next.adj_pt_idx == loc.next.pre_code_pt_idx
305 && elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 1
306 {
307 let desired_indent = next_point.get_indent().unwrap_or_default();
312
313 let (new_results, prev_point) = prev_point.indent_to(
314 tables,
315 &desired_indent,
316 None,
317 loc.target.clone().into(),
318 None,
319 None,
320 );
321
322 let (new_results, next_point) = next_point.respace_point(
323 tables,
324 elem_buff[loc.next.adj_pt_idx as usize - 1].as_block(),
325 elem_buff[loc.next.adj_pt_idx as usize + 1].as_block(),
326 root_segment,
327 new_results,
328 true,
329 "before",
330 );
331
332 elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
334 elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
335
336 new_results
337 } else {
338 fixes.push(LintFix::delete(loc.target.clone()));
339 for seg in elem_buff[loc.prev.adj_pt_idx as usize].segments() {
340 fixes.push(LintFix::delete(seg.clone()));
341 }
342
343 let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
344 tables,
345 elem_buff[(loc.next.adj_pt_idx - 1) as usize].as_block(),
346 elem_buff[(loc.next.pre_code_pt_idx + 1) as usize].as_block(),
347 root_segment,
348 Vec::new(),
349 false,
350 "after",
351 );
352
353 let mut create_anchor = None;
354 for i in 0..loc.next.pre_code_pt_idx {
355 let idx = loc.next.pre_code_pt_idx - i;
356 if let Some(elem) = elem_buff.get(idx as usize)
357 && let Some(segments) = elem.segments().last()
358 {
359 create_anchor = Some(segments.clone());
360 break;
361 }
362 }
363
364 if create_anchor.is_none() {
365 panic!("Could not find anchor for creation.");
366 }
367
368 fixes.push(LintFix::create_after(
369 create_anchor.unwrap(),
370 vec![loc.target.clone()],
371 None,
372 ));
373
374 rearrange_and_insert(&mut elem_buff, &loc, new_point);
375
376 new_results
377 }
378 } else if loc.line_position == LinePosition::Trailing {
379 if elem_buff[loc.next.newline_pt_idx as usize].num_newlines() != 0 {
380 continue;
381 }
382
383 let pretty_name = loc.pretty_target_name();
384 let _desc = if loc.strict {
385 format!(
386 "{} should always be at the end of a line.",
387 capitalize(&pretty_name)
388 )
389 } else {
390 format!("Found leading {pretty_name}. Expected only trailing near line breaks.")
391 };
392
393 if loc.prev.adj_pt_idx == loc.prev.pre_code_pt_idx
394 && elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() == 1
395 {
396 let (new_results, next_point) = next_point.indent_to(
397 tables,
398 prev_point.get_indent().as_deref().unwrap_or_default(),
399 Some(loc.target.clone()),
400 None,
401 None,
402 None,
403 );
404
405 let (new_results, prev_point) = prev_point.respace_point(
406 tables,
407 elem_buff[loc.prev.adj_pt_idx as usize - 1].as_block(),
408 elem_buff[loc.prev.adj_pt_idx as usize + 1].as_block(),
409 root_segment,
410 new_results,
411 true,
412 "before",
413 );
414
415 elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
417 elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
418
419 new_results
420 } else {
421 fixes.push(LintFix::delete(loc.target.clone()));
422 for seg in elem_buff[loc.next.adj_pt_idx as usize].segments() {
423 fixes.push(LintFix::delete(seg.clone()));
424 }
425
426 let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
427 tables,
428 elem_buff[(loc.prev.pre_code_pt_idx - 1) as usize].as_block(),
429 elem_buff[(loc.prev.adj_pt_idx + 1) as usize].as_block(),
430 root_segment,
431 Vec::new(),
432 false,
433 "before",
434 );
435
436 fixes.push(LintFix::create_before(
437 elem_buff[loc.prev.pre_code_pt_idx as usize].segments()[0].clone(),
438 vec![loc.target.clone()],
439 ));
440
441 reorder_and_insert(&mut elem_buff, &loc, new_point);
442
443 new_results
444 }
445 } else if loc.line_position == LinePosition::Alone {
446 let mut new_results = Vec::new();
447
448 let needs_next_newline =
449 elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 0;
450 let needs_prev_newline = elem_buff[loc.prev.adj_pt_idx as usize].num_newlines() == 0;
451
452 let next_code_idx = (loc.next.pre_code_pt_idx + 1) as usize;
456 let next_is_statement_terminator = next_code_idx < elem_buff.len()
457 && elem_buff[next_code_idx]
458 .segments()
459 .iter()
460 .any(|seg| seg.get_type() == SyntaxKind::StatementTerminator);
461
462 let skip_next_newline = needs_next_newline && next_is_statement_terminator;
463
464 if (!needs_next_newline || skip_next_newline) && !needs_prev_newline {
465 continue;
466 }
467
468 if needs_next_newline && !skip_next_newline {
469 let (results, next_point) = next_point.indent_to(
470 tables,
471 &deduce_line_indent(
472 loc.target.get_raw_segments().last().unwrap(),
473 root_segment,
474 ),
475 loc.target.clone().into(),
476 None,
477 None,
478 None,
479 );
480
481 new_results = results;
482 elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
483 }
484
485 if needs_prev_newline {
486 let (results, prev_point) = prev_point.indent_to(
487 tables,
488 &deduce_line_indent(
489 loc.target.get_raw_segments().first().unwrap(),
490 root_segment,
491 ),
492 None,
493 loc.target.clone().into(),
494 None,
495 None,
496 );
497
498 new_results = results;
499 elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
500 }
501
502 new_results
503 } else {
504 unimplemented!(
505 "Unexpected line_position config: {}",
506 loc.line_position.as_ref()
507 )
508 };
509
510 let fixes = fixes_from_results(new_results.into_iter())
511 .chain(std::mem::take(&mut fixes))
512 .collect();
513 lint_results.push(LintResult::new(
514 loc.target.clone().into(),
515 fixes,
516 None,
517 None,
518 ));
519 }
520
521 (elem_buff, lint_results)
522}
523
524fn rearrange_and_insert(
525 elem_buff: &mut Vec<ReflowElement>,
526 loc: &RebreakLocation,
527 new_point: ReflowPoint,
528) {
529 let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
530
531 new_buff.extend_from_slice(&elem_buff[..loc.prev.adj_pt_idx as usize]);
533
534 new_buff.extend_from_slice(
536 &elem_buff[loc.next.adj_pt_idx as usize..=loc.next.pre_code_pt_idx as usize],
537 );
538
539 if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
542 new_buff.extend_from_slice(
543 &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
544 );
545 }
546
547 new_buff.push(new_point.into());
549
550 if loc.next.pre_code_pt_idx as usize + 1 < elem_buff.len() {
552 new_buff.extend_from_slice(&elem_buff[loc.next.pre_code_pt_idx as usize + 1..]);
553 }
554
555 *elem_buff = new_buff;
557}
558
559fn reorder_and_insert(
560 elem_buff: &mut Vec<ReflowElement>,
561 loc: &RebreakLocation,
562 new_point: ReflowPoint,
563) {
564 let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
565
566 new_buff.extend_from_slice(&elem_buff[..loc.prev.pre_code_pt_idx as usize]);
568
569 new_buff.push(new_point.into());
571
572 if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
575 new_buff.extend_from_slice(
576 &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
577 );
578 }
579
580 new_buff.extend_from_slice(
583 &elem_buff[loc.prev.pre_code_pt_idx as usize..=loc.prev.adj_pt_idx as usize],
584 );
585
586 if loc.next.adj_pt_idx as usize + 1 < elem_buff.len() {
588 new_buff.extend_from_slice(&elem_buff[loc.next.adj_pt_idx as usize + 1..]);
589 }
590
591 *elem_buff = new_buff;
593}
594
#[cfg(test)]
mod tests {
    use sqruff_lib::core::test_functions::parse_ansi_string;
    use sqruff_lib_core::helpers::enter_panic;
    use sqruff_lib_core::parser::segments::Tables;

    use crate::utils::reflow::sequence::{ReflowSequence, TargetSide};

    /// Rebreak a sequence built from the whole parse tree and compare the
    /// resulting raw SQL.
    #[test]
    fn test_reflow_sequence_rebreak_root() {
        let tables = Tables::default();

        // (input SQL, expected SQL after rebreaking)
        let cases = [
            // Nothing to move.
            ("select 1", "select 1"),
            ("select 1\n+2", "select 1\n+2"),
            // Trailing operators become leading.
            ("select 1+\n2", "select 1\n+ 2"),
            ("select\n 1 +\n 2", "select\n 1\n + 2"),
            (
                "select\n 1 +\n -- comment\n 2",
                "select\n 1\n -- comment\n + 2",
            ),
            // Leading commas become trailing.
            ("select a,b", "select a,b"),
            ("select a\n,b", "select a,\nb"),
            ("select\n a\n , b", "select\n a,\n b"),
            ("select\n a\n , b", "select\n a,\n b"),
            ("select\n a\n , b", "select\n a,\n b"),
            (
                "select\n a\n -- comment\n , b",
                "select\n a,\n -- comment\n b",
            ),
        ];

        for (sql_before, sql_after) in cases {
            // Names the failing input if anything below panics.
            let _panic = enter_panic(format!("{sql_before:?}"));

            let root = parse_ansi_string(sql_before);
            let config = <_>::default();
            let sequence = ReflowSequence::from_root(&root, &config);
            let rebroken = sequence.rebreak(&tables);

            assert_eq!(rebroken.raw(), sql_after);
        }
    }

    /// Rebreak a sequence built around a single raw segment and compare both
    /// the selected window and its rebroken form.
    #[test]
    fn test_reflow_sequence_rebreak_target() {
        let tables = Tables::default();

        // (input SQL, raw segment index, expected window, expected result)
        let cases = [
            ("select 1+\n(2+3)", 4, "1+\n(", "1\n+ ("),
            ("select a,\n(b+c)", 4, "a,\n(", "a,\n("),
            ("select a\n , (b+c)", 6, "a\n , (", "a,\n ("),
            ("select a,\n(b+c)", 6, ",\n(b", ",\n(b"),
            ("select a<=b", 4, "a<=", "a<="),
        ];

        for (sql, raw_idx, window_before, window_after) in cases {
            let root = parse_ansi_string(sql);
            let anchor = &root.get_raw_segments()[raw_idx];
            let config = <_>::default();
            let sequence =
                ReflowSequence::from_around_target(anchor, &root, TargetSide::Both, &config);

            assert_eq!(sequence.raw(), window_before);

            let rebroken = sequence.rebreak(&tables);
            assert_eq!(rebroken.raw(), window_after);
        }
    }
}