Skip to main content

sqruff_lib/utils/reflow/
rebreak.rs

1use std::cmp::PartialEq;
2use std::str::FromStr;
3
4use sqruff_lib_core::dialects::syntax::SyntaxKind;
5use sqruff_lib_core::helpers::capitalize;
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::{ErasedSegment, Tables};
8use strum_macros::{AsRefStr, EnumString};
9
10use super::elements::{ReflowElement, ReflowSequenceType};
11use crate::core::rules::LintResult;
12use crate::utils::reflow::depth_map::StackPositionType;
13use crate::utils::reflow::elements::ReflowPoint;
14use crate::utils::reflow::helpers::{deduce_line_indent, fixes_from_results};
15
16#[derive(Debug)]
17pub struct RebreakSpan {
18    pub(crate) target: ErasedSegment,
19    pub(crate) start_idx: usize,
20    pub(crate) end_idx: usize,
21    pub(crate) line_position: LinePosition,
22    pub(crate) strict: bool,
23}
24
/// Indices of the interesting elements on one side of a rebreak target.
///
/// All indices point into the reflow element buffer and are computed by
/// `RebreakIndices::from_elements`, which keeps every one of them inside the
/// buffer bounds.
#[derive(Debug)]
pub struct RebreakIndices {
    /// Direction that was scanned: `1` for forwards, `-1` for backwards.
    _dir: i32,
    /// Index of the element immediately next to the target in that direction.
    adj_pt_idx: isize,
    /// Index at which the scan for a newline stopped: the first element,
    /// stepping two at a time from `adj_pt_idx`, that either contains a
    /// newline or is directly followed by code.
    newline_pt_idx: isize,
    /// Index of the last element, stepping two at a time from
    /// `newline_pt_idx`, before the next element containing code.
    pre_code_pt_idx: isize,
}
32
33impl RebreakIndices {
34    fn from_elements(elements: &ReflowSequenceType, start_idx: usize, dir: i32) -> Option<Self> {
35        assert!(dir == 1 || dir == -1);
36        let limit = if dir == -1 { 0 } else { elements.len() };
37        let adj_point_idx = start_idx as isize + dir as isize;
38
39        if adj_point_idx < 0 || adj_point_idx >= elements.len() as isize {
40            return None;
41        }
42
43        let mut newline_point_idx = adj_point_idx;
44        while (dir == 1 && newline_point_idx < limit as isize)
45            || (dir == -1 && newline_point_idx >= 0)
46        {
47            // Check bounds for the adjacent element access. When traversing
48            // backward (dir=-1) and reaching index 0, (idx + dir) would be -1
49            // which wraps to usize::MAX causing a panic. Break here as we've
50            // reached the boundary - there's nothing further in this direction.
51            let adjacent_idx = newline_point_idx + dir as isize;
52            if adjacent_idx < 0 || adjacent_idx >= elements.len() as isize {
53                break;
54            }
55            if elements[newline_point_idx as usize]
56                .class_types()
57                .contains(SyntaxKind::Newline)
58                || elements[adjacent_idx as usize]
59                    .segments()
60                    .iter()
61                    .any(|seg| seg.is_code())
62            {
63                break;
64            }
65            newline_point_idx += 2 * dir as isize;
66        }
67
68        // Clamp to the adjacent point if scanning went out of bounds
69        // (no newline or code found in this direction).
70        if newline_point_idx < 0 || newline_point_idx >= elements.len() as isize {
71            newline_point_idx = adj_point_idx;
72        }
73
74        let mut pre_code_point_idx = newline_point_idx;
75        while (dir == 1 && pre_code_point_idx < limit as isize)
76            || (dir == -1 && pre_code_point_idx >= 0)
77        {
78            // Same bounds check as above for the adjacent element access.
79            let adjacent_idx = pre_code_point_idx + dir as isize;
80            if adjacent_idx < 0 || adjacent_idx >= elements.len() as isize {
81                break;
82            }
83            if elements[adjacent_idx as usize]
84                .segments()
85                .iter()
86                .any(|seg| seg.is_code())
87            {
88                break;
89            }
90            pre_code_point_idx += 2 * dir as isize;
91        }
92
93        // Clamp to the newline point if scanning went out of bounds.
94        if pre_code_point_idx < 0 || pre_code_point_idx >= elements.len() as isize {
95            pre_code_point_idx = newline_point_idx;
96        }
97
98        RebreakIndices {
99            _dir: dir,
100            adj_pt_idx: adj_point_idx,
101            newline_pt_idx: newline_point_idx,
102            pre_code_pt_idx: pre_code_point_idx,
103        }
104        .into()
105    }
106}
107
108#[derive(Debug)]
109pub struct RebreakLocation {
110    target: ErasedSegment,
111    prev: RebreakIndices,
112    next: RebreakIndices,
113    line_position: LinePosition,
114    strict: bool,
115}
116
117#[derive(Debug, PartialEq, Clone, Copy, AsRefStr, EnumString)]
118#[strum(serialize_all = "lowercase")]
119pub enum LinePosition {
120    Leading,
121    Trailing,
122    Alone,
123    Strict,
124}
125
126impl RebreakLocation {
127    /// Expand a span to a location.
128    pub fn from_span(span: RebreakSpan, elements: &ReflowSequenceType) -> Option<Self> {
129        Self {
130            target: span.target,
131            prev: RebreakIndices::from_elements(elements, span.start_idx, -1)?,
132            next: RebreakIndices::from_elements(elements, span.end_idx, 1)?,
133            line_position: span.line_position,
134            strict: span.strict,
135        }
136        .into()
137    }
138
139    fn has_inappropriate_newlines(&self, elements: &ReflowSequenceType, strict: bool) -> bool {
140        let n_prev_newlines = elements[self.prev.newline_pt_idx as usize].num_newlines();
141        let n_next_newlines = elements[self.next.newline_pt_idx as usize].num_newlines();
142
143        let newlines_on_neither_side = n_prev_newlines + n_next_newlines == 0;
144        let newlines_on_both_sides = n_prev_newlines > 0 && n_next_newlines > 0;
145
146        (newlines_on_neither_side && !strict) || newlines_on_both_sides
147    }
148
149    fn pretty_target_name(&self) -> String {
150        format!("{} {}", self.target.get_type().as_str(), self.target.raw())
151    }
152}
153
154pub fn identify_rebreak_spans(
155    element_buffer: &ReflowSequenceType,
156    root_segment: &ErasedSegment,
157) -> Vec<RebreakSpan> {
158    let mut spans = Vec::new();
159
160    for (idx, elem) in element_buffer
161        .iter()
162        .enumerate()
163        .take(element_buffer.len() - 2)
164        .skip(2)
165    {
166        let ReflowElement::Block(block) = elem else {
167            continue;
168        };
169
170        if let Some(original_line_position) = block.line_position() {
171            let line_position = original_line_position.first().unwrap();
172            spans.push(RebreakSpan {
173                target: elem.segments().first().cloned().unwrap(),
174                start_idx: idx,
175                end_idx: idx,
176                line_position: *line_position,
177                strict: original_line_position.last() == Some(&LinePosition::Strict),
178            });
179        }
180
181        for key in block.line_position_configs().keys() {
182            let mut final_idx = None;
183            if block.depth_info().stack_positions[key].idx != 0 {
184                continue;
185            }
186
187            for (end_idx, end_elem) in element_buffer.iter().enumerate().skip(idx) {
188                let ReflowElement::Block(end_block) = end_elem else {
189                    continue;
190                };
191
192                if !end_block.depth_info().stack_positions.contains_key(key) {
193                    // Left the scope. The last block inside is two positions back.
194                    if final_idx.is_none() {
195                        final_idx = (end_idx - 2).into();
196                    }
197                    break;
198                } else if matches!(
199                    end_block.depth_info().stack_positions[key].type_,
200                    Some(StackPositionType::End) | Some(StackPositionType::Solo)
201                ) {
202                    // Track the latest End/Solo block but keep scanning,
203                    // because multiple blocks within the last child all
204                    // have End type at the parent level.
205                    final_idx = end_idx.into();
206                }
207            }
208
209            if let Some(final_idx) = final_idx {
210                let target_depth = block
211                    .depth_info()
212                    .stack_hashes
213                    .iter()
214                    .position(|it| it == key)
215                    .unwrap();
216                let target = root_segment.path_to(&element_buffer[idx].segments()[0])[target_depth]
217                    .segment
218                    .clone();
219
220                let line_position_configs = block.line_position_configs()[key]
221                    .split(':')
222                    .next()
223                    .unwrap();
224                let line_position = LinePosition::from_str(line_position_configs).unwrap();
225
226                spans.push(RebreakSpan {
227                    target,
228                    start_idx: idx,
229                    end_idx: final_idx,
230                    line_position,
231                    strict: block.line_position_configs()[key].ends_with("strict"),
232                });
233            }
234        }
235    }
236
237    spans
238}
239
240pub fn rebreak_sequence(
241    tables: &Tables,
242    elements: ReflowSequenceType,
243    root_segment: &ErasedSegment,
244) -> (ReflowSequenceType, Vec<LintResult>) {
245    let mut lint_results = Vec::new();
246    let mut fixes = Vec::new();
247    let mut elem_buff = elements.clone();
248
249    // Given a sequence we should identify the objects which
250    // make sense to rebreak. That includes any raws with config,
251    // but also and parent segments which have config and we can
252    // find both ends for. Given those spans, we then need to find
253    // the points either side of them and then the blocks either
254    // side to respace them at the same time.
255
256    // 1. First find appropriate spans.
257    let spans = identify_rebreak_spans(&elem_buff, root_segment);
258
259    let mut locations = Vec::new();
260    for span in spans {
261        if let Some(loc) = RebreakLocation::from_span(span, &elements) {
262            locations.push(loc);
263        }
264    }
265
266    // Handle each span:
267    for loc in locations {
268        if loc.has_inappropriate_newlines(&elements, loc.strict) {
269            continue;
270        }
271
272        // if loc.has_templated_newline(elem_buff) {
273        //     continue;
274        // }
275
276        // Points and blocks either side are just offsets from the indices.
277        let prev_point = elem_buff[loc.prev.adj_pt_idx as usize]
278            .as_point()
279            .unwrap()
280            .clone();
281        let next_point = elem_buff[loc.next.adj_pt_idx as usize]
282            .as_point()
283            .unwrap()
284            .clone();
285
286        // So we know we have a preference, is it ok?
287        let new_results = if loc.line_position == LinePosition::Leading {
288            if elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() != 0 {
289                // We're good. It's already leading.
290                continue;
291            }
292
293            // Generate the text for any issues.
294            let pretty_name = loc.pretty_target_name();
295            let _desc = if loc.strict {
296                format!(
297                    "{} should always start a new line.",
298                    capitalize(&pretty_name)
299                )
300            } else {
301                format!("Found trailing {pretty_name}. Expected only leading near line breaks.")
302            };
303
304            if loc.next.adj_pt_idx == loc.next.pre_code_pt_idx
305                && elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 1
306            {
307                // Simple case. No comments.
308                // Strip newlines from the next point.
309                // Apply the indent to the previous point.
310
311                let desired_indent = next_point.get_indent().unwrap_or_default();
312
313                let (new_results, prev_point) = prev_point.indent_to(
314                    tables,
315                    &desired_indent,
316                    None,
317                    loc.target.clone().into(),
318                    None,
319                    None,
320                );
321
322                let (new_results, next_point) = next_point.respace_point(
323                    tables,
324                    elem_buff[loc.next.adj_pt_idx as usize - 1].as_block(),
325                    elem_buff[loc.next.adj_pt_idx as usize + 1].as_block(),
326                    root_segment,
327                    new_results,
328                    true,
329                    "before",
330                );
331
332                // Update the points in the buffer
333                elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
334                elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
335
336                new_results
337            } else {
338                fixes.push(LintFix::delete(loc.target.clone()));
339                for seg in elem_buff[loc.prev.adj_pt_idx as usize].segments() {
340                    fixes.push(LintFix::delete(seg.clone()));
341                }
342
343                let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
344                    tables,
345                    elem_buff[(loc.next.adj_pt_idx - 1) as usize].as_block(),
346                    elem_buff[(loc.next.pre_code_pt_idx + 1) as usize].as_block(),
347                    root_segment,
348                    Vec::new(),
349                    false,
350                    "after",
351                );
352
353                let mut create_anchor = None;
354                for i in 0..loc.next.pre_code_pt_idx {
355                    let idx = loc.next.pre_code_pt_idx - i;
356                    if let Some(elem) = elem_buff.get(idx as usize)
357                        && let Some(segments) = elem.segments().last()
358                    {
359                        create_anchor = Some(segments.clone());
360                        break;
361                    }
362                }
363
364                if create_anchor.is_none() {
365                    panic!("Could not find anchor for creation.");
366                }
367
368                fixes.push(LintFix::create_after(
369                    create_anchor.unwrap(),
370                    vec![loc.target.clone()],
371                    None,
372                ));
373
374                rearrange_and_insert(&mut elem_buff, &loc, new_point);
375
376                new_results
377            }
378        } else if loc.line_position == LinePosition::Trailing {
379            if elem_buff[loc.next.newline_pt_idx as usize].num_newlines() != 0 {
380                continue;
381            }
382
383            let pretty_name = loc.pretty_target_name();
384            let _desc = if loc.strict {
385                format!(
386                    "{} should always be at the end of a line.",
387                    capitalize(&pretty_name)
388                )
389            } else {
390                format!("Found leading {pretty_name}. Expected only trailing near line breaks.")
391            };
392
393            if loc.prev.adj_pt_idx == loc.prev.pre_code_pt_idx
394                && elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() == 1
395            {
396                let (new_results, next_point) = next_point.indent_to(
397                    tables,
398                    prev_point.get_indent().as_deref().unwrap_or_default(),
399                    Some(loc.target.clone()),
400                    None,
401                    None,
402                    None,
403                );
404
405                let (new_results, prev_point) = prev_point.respace_point(
406                    tables,
407                    elem_buff[loc.prev.adj_pt_idx as usize - 1].as_block(),
408                    elem_buff[loc.prev.adj_pt_idx as usize + 1].as_block(),
409                    root_segment,
410                    new_results,
411                    true,
412                    "before",
413                );
414
415                // Update the points in the buffer
416                elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
417                elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
418
419                new_results
420            } else {
421                fixes.push(LintFix::delete(loc.target.clone()));
422                for seg in elem_buff[loc.next.adj_pt_idx as usize].segments() {
423                    fixes.push(LintFix::delete(seg.clone()));
424                }
425
426                let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
427                    tables,
428                    elem_buff[(loc.prev.pre_code_pt_idx - 1) as usize].as_block(),
429                    elem_buff[(loc.prev.adj_pt_idx + 1) as usize].as_block(),
430                    root_segment,
431                    Vec::new(),
432                    false,
433                    "before",
434                );
435
436                fixes.push(LintFix::create_before(
437                    elem_buff[loc.prev.pre_code_pt_idx as usize].segments()[0].clone(),
438                    vec![loc.target.clone()],
439                ));
440
441                reorder_and_insert(&mut elem_buff, &loc, new_point);
442
443                new_results
444            }
445        } else if loc.line_position == LinePosition::Alone {
446            let mut new_results = Vec::new();
447
448            let needs_next_newline =
449                elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 0;
450            let needs_prev_newline = elem_buff[loc.prev.adj_pt_idx as usize].num_newlines() == 0;
451
452            // Don't add a newline before a statement terminator (semicolon).
453            // The terminator should stay on the same line as the preceding
454            // clause content (e.g. `WHERE a = 1;` not `WHERE a = 1\n;`).
455            let next_code_idx = (loc.next.pre_code_pt_idx + 1) as usize;
456            let next_is_statement_terminator = next_code_idx < elem_buff.len()
457                && elem_buff[next_code_idx]
458                    .segments()
459                    .iter()
460                    .any(|seg| seg.get_type() == SyntaxKind::StatementTerminator);
461
462            let skip_next_newline = needs_next_newline && next_is_statement_terminator;
463
464            if (!needs_next_newline || skip_next_newline) && !needs_prev_newline {
465                continue;
466            }
467
468            if needs_next_newline && !skip_next_newline {
469                let (results, next_point) = next_point.indent_to(
470                    tables,
471                    &deduce_line_indent(
472                        loc.target.get_raw_segments().last().unwrap(),
473                        root_segment,
474                    ),
475                    loc.target.clone().into(),
476                    None,
477                    None,
478                    None,
479                );
480
481                new_results = results;
482                elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
483            }
484
485            if needs_prev_newline {
486                let (results, prev_point) = prev_point.indent_to(
487                    tables,
488                    &deduce_line_indent(
489                        loc.target.get_raw_segments().first().unwrap(),
490                        root_segment,
491                    ),
492                    None,
493                    loc.target.clone().into(),
494                    None,
495                    None,
496                );
497
498                new_results = results;
499                elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
500            }
501
502            new_results
503        } else {
504            unimplemented!(
505                "Unexpected line_position config: {}",
506                loc.line_position.as_ref()
507            )
508        };
509
510        let fixes = fixes_from_results(new_results.into_iter())
511            .chain(std::mem::take(&mut fixes))
512            .collect();
513        lint_results.push(LintResult::new(
514            loc.target.clone().into(),
515            fixes,
516            None,
517            None,
518        ));
519    }
520
521    (elem_buff, lint_results)
522}
523
524fn rearrange_and_insert(
525    elem_buff: &mut Vec<ReflowElement>,
526    loc: &RebreakLocation,
527    new_point: ReflowPoint,
528) {
529    let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
530
531    // First segment: up to loc.prev.adj_pt_idx (exclusive)
532    new_buff.extend_from_slice(&elem_buff[..loc.prev.adj_pt_idx as usize]);
533
534    // Second segment: loc.next.adj_pt_idx to loc.next.pre_code_pt_idx (inclusive)
535    new_buff.extend_from_slice(
536        &elem_buff[loc.next.adj_pt_idx as usize..=loc.next.pre_code_pt_idx as usize],
537    );
538
539    // Third segment: loc.prev.adj_pt_idx + 1 to loc.next.adj_pt_idx (exclusive, the
540    // target)
541    if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
542        new_buff.extend_from_slice(
543            &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
544        );
545    }
546
547    // Insert new_point here
548    new_buff.push(new_point.into());
549
550    // Last segment: after loc.next.pre_code_pt_idx
551    if loc.next.pre_code_pt_idx as usize + 1 < elem_buff.len() {
552        new_buff.extend_from_slice(&elem_buff[loc.next.pre_code_pt_idx as usize + 1..]);
553    }
554
555    // Replace old buffer with the new one
556    *elem_buff = new_buff;
557}
558
559fn reorder_and_insert(
560    elem_buff: &mut Vec<ReflowElement>,
561    loc: &RebreakLocation,
562    new_point: ReflowPoint,
563) {
564    let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
565
566    // First segment: up to loc.prev.pre_code_pt_idx (exclusive)
567    new_buff.extend_from_slice(&elem_buff[..loc.prev.pre_code_pt_idx as usize]);
568
569    // Insert new_point here
570    new_buff.push(new_point.into());
571
572    // Second segment: loc.prev.adj_pt_idx + 1 to loc.next.adj_pt_idx (exclusive,
573    // the target)
574    if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
575        new_buff.extend_from_slice(
576            &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
577        );
578    }
579
580    // Third segment: loc.prev.pre_code_pt_idx to loc.prev.adj_pt_idx + 1
581    // (inclusive)
582    new_buff.extend_from_slice(
583        &elem_buff[loc.prev.pre_code_pt_idx as usize..=loc.prev.adj_pt_idx as usize],
584    );
585
586    // Last segment: after loc.next.adj_pt_idx
587    if loc.next.adj_pt_idx as usize + 1 < elem_buff.len() {
588        new_buff.extend_from_slice(&elem_buff[loc.next.adj_pt_idx as usize + 1..]);
589    }
590
591    // Replace old buffer with the new one
592    *elem_buff = new_buff;
593}
594
#[cfg(test)]
mod tests {
    use sqruff_lib::core::test_functions::parse_ansi_string;
    use sqruff_lib_core::helpers::enter_panic;
    use sqruff_lib_core::parser::segments::Tables;

    use crate::utils::reflow::sequence::{ReflowSequence, TargetSide};

    /// Rebreaking a sequence built from the whole parse tree: each case is
    /// `(input SQL, expected SQL after rebreak)`.
    #[test]
    fn test_reflow_sequence_rebreak_root() {
        const CASES: &[(&str, &str)] = &[
            // Trivial Case
            ("select 1", "select 1"),
            // These rely on the default config being for leading operators
            ("select 1\n+2", "select 1\n+2"),
            ("select 1+\n2", "select 1\n+ 2"), // NOTE: Implicit respace.
            ("select\n  1 +\n  2", "select\n  1\n  + 2"),
            (
                "select\n  1 +\n  -- comment\n  2",
                "select\n  1\n  -- comment\n  + 2",
            ),
            // These rely on the default config being for trailing commas
            ("select a,b", "select a,b"),
            ("select a\n,b", "select a,\nb"),
            ("select\n  a\n  , b", "select\n  a,\n  b"),
            ("select\n    a\n    , b", "select\n    a,\n    b"),
            ("select\n  a\n    , b", "select\n  a,\n    b"),
            (
                "select\n  a\n  -- comment\n  , b",
                "select\n  a,\n  -- comment\n  b",
            ),
        ];

        let tables = Tables::default();
        for &(sql_in, sql_out) in CASES {
            // Tags any panic below with the input that triggered it.
            let _panic = enter_panic(format!("{sql_in:?}"));

            let root = parse_ansi_string(sql_in);
            let config = <_>::default();
            let rebroken = ReflowSequence::from_root(&root, &config).rebreak(&tables);

            assert_eq!(rebroken.raw(), sql_out);
        }
    }

    /// Rebreaking a sequence built around a single raw segment: each case is
    /// `(input SQL, index of the target raw segment, expected raw of the
    /// sequence before rebreak, expected raw after rebreak)`.
    #[test]
    fn test_reflow_sequence_rebreak_target() {
        const CASES: &[(&str, usize, &str, &str)] = &[
            ("select 1+\n(2+3)", 4, "1+\n(", "1\n+ ("),
            ("select a,\n(b+c)", 4, "a,\n(", "a,\n("),
            ("select a\n  , (b+c)", 6, "a\n  , (", "a,\n  ("),
            // Here we don't have enough context to rebreak it so
            // it should be left unaltered.
            ("select a,\n(b+c)", 6, ",\n(b", ",\n(b"),
            // This intentionally targets an incomplete span.
            ("select a<=b", 4, "a<=", "a<="),
        ];

        let tables = Tables::default();
        for &(sql_in, raw_idx, window_before, window_after) in CASES {
            let root = parse_ansi_string(sql_in);
            let target = &root.get_raw_segments()[raw_idx];
            let config = <_>::default();
            let window =
                ReflowSequence::from_around_target(target, &root, TargetSide::Both, &config);

            assert_eq!(window.raw(), window_before);

            let rebroken = window.rebreak(&tables);
            assert_eq!(rebroken.raw(), window_after);
        }
    }
}
667}