sqruff_lib_core/parser/segments/
base.rs

1use std::cell::{Cell, OnceCell};
2use std::fmt::Debug;
3use std::hash::{Hash, Hasher};
4use std::rc::Rc;
5
6use itertools::enumerate;
7use rustc_hash::FxHashMap;
8use smol_str::SmolStr;
9
10use crate::dialects::init::DialectKind;
11use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
12use crate::edit_type::EditType;
13use crate::parser::markers::PositionMarker;
14use crate::parser::segments::fix::{FixPatch, SourceFix};
15use crate::parser::segments::object_reference::{ObjectReferenceKind, ObjectReferenceSegment};
16use crate::segments::AnchorEditInfo;
17use crate::templaters::base::TemplatedFile;
18
19pub struct SegmentBuilder {
20    node_or_token: NodeOrToken,
21}
22
23impl SegmentBuilder {
24    pub fn whitespace(id: u32, raw: &str) -> ErasedSegment {
25        SegmentBuilder::token(id, raw, SyntaxKind::Whitespace).finish()
26    }
27
28    pub fn newline(id: u32, raw: &str) -> ErasedSegment {
29        SegmentBuilder::token(id, raw, SyntaxKind::Newline).finish()
30    }
31
32    pub fn keyword(id: u32, raw: &str) -> ErasedSegment {
33        SegmentBuilder::token(id, raw, SyntaxKind::Keyword).finish()
34    }
35
36    pub fn comma(id: u32) -> ErasedSegment {
37        SegmentBuilder::token(id, ",", SyntaxKind::Comma).finish()
38    }
39
40    pub fn symbol(id: u32, raw: &str) -> ErasedSegment {
41        SegmentBuilder::token(id, raw, SyntaxKind::Symbol).finish()
42    }
43
44    pub fn node(
45        id: u32,
46        syntax_kind: SyntaxKind,
47        dialect: DialectKind,
48        segments: Vec<ErasedSegment>,
49    ) -> Self {
50        SegmentBuilder {
51            node_or_token: NodeOrToken {
52                id,
53                syntax_kind,
54                class_types: class_types(syntax_kind),
55                position_marker: None,
56                code_idx: OnceCell::new(),
57                kind: NodeOrTokenKind::Node(NodeData {
58                    dialect,
59                    segments,
60                    raw: Default::default(),
61                    source_fixes: vec![],
62                    descendant_type_set: Default::default(),
63                    raw_segments_with_ancestors: Default::default(),
64                }),
65                hash: OnceCell::new(),
66            },
67        }
68    }
69
70    pub fn token(id: u32, raw: &str, syntax_kind: SyntaxKind) -> Self {
71        SegmentBuilder {
72            node_or_token: NodeOrToken {
73                id,
74                syntax_kind,
75                code_idx: OnceCell::new(),
76                class_types: class_types(syntax_kind),
77                position_marker: None,
78                kind: NodeOrTokenKind::Token(TokenData { raw: raw.into() }),
79                hash: OnceCell::new(),
80            },
81        }
82    }
83
84    pub fn position_from_segments(mut self) -> Self {
85        let segments = match &self.node_or_token.kind {
86            NodeOrTokenKind::Node(node) => &node.segments[..],
87            NodeOrTokenKind::Token(_) => &[],
88        };
89
90        self.node_or_token.position_marker = pos_marker(segments).into();
91        self
92    }
93
94    pub fn with_position(mut self, position: PositionMarker) -> Self {
95        self.node_or_token.position_marker = Some(position);
96        self
97    }
98
99    pub fn finish(self) -> ErasedSegment {
100        ErasedSegment {
101            value: Rc::new(self.node_or_token),
102        }
103    }
104}
105
/// Source of sequential segment ids; a default instance starts at zero.
#[derive(Debug, Default)]
pub struct Tables {
    counter: Cell<u32>,
}

impl Tables {
    /// Returns the current counter value and advances the counter by one.
    pub fn next_id(&self) -> u32 {
        let current = self.counter.get();
        // `replace` stores the successor and hands back the value it held.
        self.counter.replace(current + 1)
    }
}
118
119#[derive(Debug, Clone)]
120pub struct ErasedSegment {
121    pub(crate) value: Rc<NodeOrToken>,
122}
123
124impl Hash for ErasedSegment {
125    fn hash<H: Hasher>(&self, state: &mut H) {
126        self.hash_value().hash(state);
127    }
128}
129
130impl Eq for ErasedSegment {}
131
132impl ErasedSegment {
133    pub fn raw(&self) -> &SmolStr {
134        match &self.value.kind {
135            NodeOrTokenKind::Node(node) => node.raw.get_or_init(|| {
136                SmolStr::from_iter(self.segments().iter().map(|segment| segment.raw().as_str()))
137            }),
138            NodeOrTokenKind::Token(token) => &token.raw,
139        }
140    }
141
142    pub fn segments(&self) -> &[ErasedSegment] {
143        match &self.value.kind {
144            NodeOrTokenKind::Node(node) => &node.segments,
145            NodeOrTokenKind::Token(_) => &[],
146        }
147    }
148
149    pub fn get_type(&self) -> SyntaxKind {
150        self.value.syntax_kind
151    }
152
153    pub fn is_type(&self, kind: SyntaxKind) -> bool {
154        self.get_type() == kind
155    }
156
157    pub fn is_meta(&self) -> bool {
158        matches!(
159            self.value.syntax_kind,
160            SyntaxKind::Indent | SyntaxKind::Implicit | SyntaxKind::Dedent | SyntaxKind::EndOfFile
161        )
162    }
163
164    pub fn is_code(&self) -> bool {
165        match &self.value.kind {
166            NodeOrTokenKind::Node(node) => node.segments.iter().any(|s| s.is_code()),
167            NodeOrTokenKind::Token(_) => {
168                !self.is_comment() && !self.is_whitespace() && !self.is_meta()
169            }
170        }
171    }
172
173    pub fn get_raw_segments(&self) -> Vec<ErasedSegment> {
174        self.recursive_crawl_all(false)
175            .into_iter()
176            .filter(|it| it.segments().is_empty())
177            .collect()
178    }
179
180    #[cfg(feature = "stringify")]
181    pub fn stringify(&self, code_only: bool) -> String {
182        serde_yaml::to_string(&self.to_serialised(code_only, true)).unwrap()
183    }
184
185    pub fn child(&self, seg_types: &SyntaxSet) -> Option<ErasedSegment> {
186        self.segments()
187            .iter()
188            .find(|seg| seg_types.contains(seg.get_type()))
189            .cloned()
190    }
191
192    pub fn recursive_crawl(
193        &self,
194        types: &SyntaxSet,
195        recurse_into: bool,
196        no_recursive_types: &SyntaxSet,
197        allow_self: bool,
198    ) -> Vec<ErasedSegment> {
199        let mut acc = Vec::new();
200
201        let matches = allow_self && self.class_types().intersects(types);
202        if matches {
203            acc.push(self.clone());
204        }
205
206        if !self.descendant_type_set().intersects(types) {
207            return acc;
208        }
209
210        if recurse_into || !matches {
211            for seg in self.segments() {
212                if no_recursive_types.is_empty() || !no_recursive_types.contains(seg.get_type()) {
213                    let segments =
214                        seg.recursive_crawl(types, recurse_into, no_recursive_types, true);
215                    acc.extend(segments);
216                }
217            }
218        }
219
220        acc
221    }
222}
223
224impl ErasedSegment {
225    #[allow(clippy::new_ret_no_self, clippy::wrong_self_convention)]
226    #[track_caller]
227    pub fn new(&self, segments: Vec<ErasedSegment>) -> ErasedSegment {
228        match &self.value.kind {
229            NodeOrTokenKind::Node(node) => SegmentBuilder::node(
230                self.value.id,
231                self.value.syntax_kind,
232                node.dialect,
233                segments,
234            )
235            .with_position(self.get_position_marker().unwrap().clone())
236            .finish(),
237            NodeOrTokenKind::Token(_) => self.deep_clone(),
238        }
239    }
240
241    fn change_segments(&self, segments: Vec<ErasedSegment>) -> ErasedSegment {
242        let NodeOrTokenKind::Node(node) = &self.value.kind else {
243            unimplemented!()
244        };
245
246        ErasedSegment {
247            value: Rc::new(NodeOrToken {
248                id: self.value.id,
249                syntax_kind: self.value.syntax_kind,
250                class_types: self.value.class_types.clone(),
251                position_marker: None,
252                code_idx: OnceCell::new(),
253                kind: NodeOrTokenKind::Node(NodeData {
254                    dialect: node.dialect,
255                    segments,
256                    raw: node.raw.clone(),
257                    source_fixes: node.source_fixes.clone(),
258                    descendant_type_set: node.descendant_type_set.clone(),
259                    raw_segments_with_ancestors: node.raw_segments_with_ancestors.clone(),
260                }),
261                hash: OnceCell::new(),
262            }),
263        }
264    }
265
266    pub fn indent_val(&self) -> i8 {
267        self.value.syntax_kind.indent_val()
268    }
269
270    pub fn can_start_end_non_code(&self) -> bool {
271        matches!(
272            self.value.syntax_kind,
273            SyntaxKind::File | SyntaxKind::Unparsable
274        )
275    }
276
277    pub(crate) fn dialect(&self) -> DialectKind {
278        match &self.value.kind {
279            NodeOrTokenKind::Node(node) => node.dialect,
280            NodeOrTokenKind::Token(_) => todo!(),
281        }
282    }
283
284    pub fn get_start_loc(&self) -> (usize, usize) {
285        match self.get_position_marker() {
286            Some(pos_marker) => pos_marker.working_loc(),
287            None => unreachable!("{self:?} has no PositionMarker"),
288        }
289    }
290
291    pub fn get_end_loc(&self) -> (usize, usize) {
292        match self.get_position_marker() {
293            Some(pos_marker) => pos_marker.working_loc_after(self.raw()),
294            None => {
295                unreachable!("{self:?} has no PositionMarker")
296            }
297        }
298    }
299
300    pub fn is_templated(&self) -> bool {
301        if let Some(pos_marker) = self.get_position_marker() {
302            pos_marker.source_slice.start != pos_marker.source_slice.end && !pos_marker.is_literal()
303        } else {
304            panic!("PosMarker must be set");
305        }
306    }
307
308    pub fn iter_segments(&self, expanding: &SyntaxSet, pass_through: bool) -> Vec<ErasedSegment> {
309        let capacity = if expanding.is_empty() {
310            self.segments().len()
311        } else {
312            0
313        };
314        let mut result = Vec::with_capacity(capacity);
315        for segment in self.segments() {
316            if expanding.contains(segment.get_type()) {
317                let expanding = if pass_through {
318                    expanding
319                } else {
320                    &SyntaxSet::EMPTY
321                };
322                result.append(&mut segment.iter_segments(expanding, false));
323            } else {
324                result.push(segment.clone());
325            }
326        }
327        result
328    }
329
330    pub(crate) fn code_indices(&self) -> Rc<Vec<usize>> {
331        self.value
332            .code_idx
333            .get_or_init(|| {
334                Rc::from(
335                    self.segments()
336                        .iter()
337                        .enumerate()
338                        .filter(|(_, seg)| seg.is_code())
339                        .map(|(idx, _)| idx)
340                        .collect::<Vec<_>>(),
341                )
342            })
343            .clone()
344    }
345
346    pub fn children(
347        &self,
348        seg_types: &'static SyntaxSet,
349    ) -> impl Iterator<Item = &ErasedSegment> + '_ {
350        self.segments()
351            .iter()
352            .filter(move |seg| seg_types.contains(seg.get_type()))
353    }
354
355    pub fn iter_patches(&self, templated_file: &TemplatedFile) -> Vec<FixPatch> {
356        let mut acc = Vec::new();
357
358        let templated_raw = &templated_file.templated_str.as_ref().unwrap()
359            [self.get_position_marker().unwrap().templated_slice.clone()];
360        if self.raw() == templated_raw {
361            acc.extend(self.iter_source_fix_patches(templated_file));
362            return acc;
363        }
364
365        if self.get_position_marker().is_none() {
366            return Vec::new();
367        }
368
369        let pos_marker = self.get_position_marker().unwrap();
370        if pos_marker.is_literal() {
371            acc.extend(self.iter_source_fix_patches(templated_file));
372            acc.push(FixPatch::new(
373                pos_marker.templated_slice.clone(),
374                self.raw().clone(),
375                // SyntaxKind::Literal.into(),
376                pos_marker.source_slice.clone(),
377                templated_file.templated_str.as_ref().unwrap()[pos_marker.templated_slice.clone()]
378                    .to_string(),
379                templated_file.source_str[pos_marker.source_slice.clone()].to_string(),
380            ));
381        } else if self.segments().is_empty() {
382            return acc;
383        } else {
384            let mut segments = self.segments();
385
386            while !segments.is_empty()
387                && matches!(
388                    segments.last().unwrap().get_type(),
389                    SyntaxKind::EndOfFile
390                        | SyntaxKind::Indent
391                        | SyntaxKind::Dedent
392                        | SyntaxKind::Implicit
393                )
394            {
395                segments = &segments[..segments.len() - 1];
396            }
397
398            let pos = self.get_position_marker().unwrap();
399            let mut source_idx = pos.source_slice.start;
400            let mut templated_idx = pos.templated_slice.start;
401            let mut insert_buff = String::new();
402
403            for segment in segments {
404                let pos_marker = segment.get_position_marker().unwrap();
405                if !segment.raw().is_empty() && pos_marker.is_point() {
406                    insert_buff.push_str(segment.raw().as_ref());
407                    continue;
408                }
409
410                let start_diff = pos_marker.templated_slice.start - templated_idx;
411
412                if start_diff > 0 || !insert_buff.is_empty() {
413                    let fixed_raw = std::mem::take(&mut insert_buff);
414                    let raw_segments = segment.get_raw_segments();
415                    let first_segment_pos = raw_segments[0].get_position_marker().unwrap();
416
417                    acc.push(FixPatch::new(
418                        templated_idx..first_segment_pos.templated_slice.start,
419                        fixed_raw.into(),
420                        source_idx..first_segment_pos.source_slice.start,
421                        String::new(),
422                        String::new(),
423                    ));
424                }
425
426                acc.extend(segment.iter_patches(templated_file));
427
428                source_idx = pos_marker.source_slice.end;
429                templated_idx = pos_marker.templated_slice.end;
430            }
431
432            let end_diff = pos.templated_slice.end - templated_idx;
433            if end_diff != 0 || !insert_buff.is_empty() {
434                let source_slice = source_idx..pos.source_slice.end;
435                let templated_slice = templated_idx..pos.templated_slice.end;
436
437                let templated_str = templated_file.templated_str.as_ref().unwrap()
438                    [templated_slice.clone()]
439                .to_owned();
440                let source_str = templated_file.source_str[source_slice.clone()].to_owned();
441
442                acc.push(FixPatch::new(
443                    templated_slice,
444                    insert_buff.into(),
445                    source_slice,
446                    templated_str,
447                    source_str,
448                ));
449            }
450        }
451
452        acc
453    }
454
455    pub fn descendant_type_set(&self) -> &SyntaxSet {
456        match &self.value.kind {
457            NodeOrTokenKind::Node(node) => node.descendant_type_set.get_or_init(|| {
458                self.segments()
459                    .iter()
460                    .flat_map(|segment| {
461                        segment
462                            .descendant_type_set()
463                            .clone()
464                            .union(segment.class_types())
465                    })
466                    .collect()
467            }),
468            NodeOrTokenKind::Token(_) => const { &SyntaxSet::EMPTY },
469        }
470    }
471
472    pub fn is_comment(&self) -> bool {
473        matches!(
474            self.value.syntax_kind,
475            SyntaxKind::Comment | SyntaxKind::InlineComment | SyntaxKind::BlockComment
476        )
477    }
478
479    pub fn is_whitespace(&self) -> bool {
480        matches!(
481            self.value.syntax_kind,
482            SyntaxKind::Whitespace | SyntaxKind::Newline
483        )
484    }
485
486    pub fn is_indent(&self) -> bool {
487        matches!(
488            self.value.syntax_kind,
489            SyntaxKind::Indent | SyntaxKind::Implicit | SyntaxKind::Dedent
490        )
491    }
492
493    pub fn get_position_marker(&self) -> Option<&PositionMarker> {
494        self.value.position_marker.as_ref()
495    }
496
497    pub(crate) fn iter_source_fix_patches(&self, templated_file: &TemplatedFile) -> Vec<FixPatch> {
498        let source_fixes = self.get_source_fixes();
499        let mut patches = Vec::with_capacity(source_fixes.len());
500
501        for source_fix in &source_fixes {
502            patches.push(FixPatch::new(
503                source_fix.templated_slice.clone(),
504                source_fix.edit.clone(),
505                // String::from("source"),
506                source_fix.source_slice.clone(),
507                templated_file.templated_str.clone().unwrap()[source_fix.templated_slice.clone()]
508                    .to_string(),
509                templated_file.source_str[source_fix.source_slice.clone()].to_string(),
510            ));
511        }
512
513        patches
514    }
515
516    pub fn id(&self) -> u32 {
517        self.value.id
518    }
519
520    /// Return any source fixes as list.
521    pub fn get_source_fixes(&self) -> Vec<SourceFix> {
522        match &self.value.kind {
523            NodeOrTokenKind::Node(node) => node.source_fixes.clone(),
524            NodeOrTokenKind::Token(_) => Vec::new(),
525        }
526    }
527
528    pub fn edit(
529        &self,
530        id: u32,
531        raw: Option<String>,
532        _source_fixes: Option<Vec<SourceFix>>,
533    ) -> ErasedSegment {
534        match &self.value.kind {
535            NodeOrTokenKind::Node(_node) => {
536                todo!()
537            }
538            NodeOrTokenKind::Token(token) => {
539                let raw = raw.as_deref().unwrap_or(token.raw.as_ref());
540                SegmentBuilder::token(id, raw, self.value.syntax_kind)
541                    .with_position(self.get_position_marker().unwrap().clone())
542                    .finish()
543            }
544        }
545    }
546
547    pub fn class_types(&self) -> &SyntaxSet {
548        &self.value.class_types
549    }
550
551    pub(crate) fn first_non_whitespace_segment_raw_upper(&self) -> Option<String> {
552        for seg in self.get_raw_segments() {
553            if !seg.raw().is_empty() {
554                return Some(seg.raw().to_uppercase());
555            }
556        }
557        None
558    }
559
560    pub fn is(&self, other: &ErasedSegment) -> bool {
561        Rc::ptr_eq(&self.value, &other.value)
562    }
563
564    pub fn addr(&self) -> usize {
565        Rc::as_ptr(&self.value).addr()
566    }
567
568    pub fn direct_descendant_type_set(&self) -> SyntaxSet {
569        self.segments()
570            .iter()
571            .fold(SyntaxSet::EMPTY, |set, it| set.union(it.class_types()))
572    }
573
574    pub fn is_keyword(&self, p0: &str) -> bool {
575        self.is_type(SyntaxKind::Keyword) && self.raw().eq_ignore_ascii_case(p0)
576    }
577
578    pub fn hash_value(&self) -> u64 {
579        *self.value.hash.get_or_init(|| {
580            let mut hasher = ahash::AHasher::default();
581            self.get_type().hash(&mut hasher);
582            self.raw().hash(&mut hasher);
583
584            if let Some(marker) = &self.get_position_marker() {
585                marker.source_position().hash(&mut hasher);
586            } else {
587                None::<usize>.hash(&mut hasher);
588            }
589
590            hasher.finish()
591        })
592    }
593
594    pub fn deep_clone(&self) -> Self {
595        Self {
596            value: Rc::new(self.value.as_ref().clone()),
597        }
598    }
599
600    #[track_caller]
601    pub(crate) fn get_mut(&mut self) -> &mut NodeOrToken {
602        Rc::get_mut(&mut self.value).unwrap()
603    }
604
605    #[track_caller]
606    pub(crate) fn make_mut(&mut self) -> &mut NodeOrToken {
607        Rc::make_mut(&mut self.value)
608    }
609
610    pub fn reference(&self) -> ObjectReferenceSegment {
611        ObjectReferenceSegment(
612            self.clone(),
613            match self.get_type() {
614                SyntaxKind::TableReference => ObjectReferenceKind::Table,
615                SyntaxKind::WildcardIdentifier => ObjectReferenceKind::WildcardIdentifier,
616                _ => ObjectReferenceKind::Object,
617            },
618        )
619    }
620
621    pub fn recursive_crawl_all(&self, reverse: bool) -> Vec<ErasedSegment> {
622        let mut result = Vec::with_capacity(self.segments().len() + 1);
623
624        if reverse {
625            for seg in self.segments().iter().rev() {
626                result.append(&mut seg.recursive_crawl_all(reverse));
627            }
628            result.push(self.clone());
629        } else {
630            result.push(self.clone());
631            for seg in self.segments() {
632                result.append(&mut seg.recursive_crawl_all(reverse));
633            }
634        }
635
636        result
637    }
638
639    pub fn raw_segments_with_ancestors(&self) -> &[(ErasedSegment, Vec<PathStep>)] {
640        match &self.value.kind {
641            NodeOrTokenKind::Node(node) => node.raw_segments_with_ancestors.get_or_init(|| {
642                let mut buffer: Vec<(ErasedSegment, Vec<PathStep>)> =
643                    Vec::with_capacity(self.segments().len());
644                let code_idxs = self.code_indices();
645
646                for (idx, seg) in self.segments().iter().enumerate() {
647                    let new_step = vec![PathStep {
648                        segment: self.clone(),
649                        idx,
650                        len: self.segments().len(),
651                        code_idxs: code_idxs.clone(),
652                    }];
653
654                    // Use seg.get_segments().is_empty() as a workaround to check if the segment is
655                    // a SyntaxKind::Raw type. In the original Python code, this was achieved
656                    // using seg.is_type(SyntaxKind::Raw). Here, we assume that a SyntaxKind::Raw
657                    // segment is characterized by having no sub-segments.
658
659                    if seg.segments().is_empty() {
660                        buffer.push((seg.clone(), new_step));
661                    } else {
662                        let extended =
663                            seg.raw_segments_with_ancestors()
664                                .iter()
665                                .map(|(raw_seg, stack)| {
666                                    let mut new_step = new_step.clone();
667                                    new_step.extend_from_slice(stack);
668                                    (raw_seg.clone(), new_step)
669                                });
670
671                        buffer.extend(extended);
672                    }
673                }
674
675                buffer
676            }),
677            NodeOrTokenKind::Token(_) => &[],
678        }
679    }
680
681    pub fn path_to(&self, other: &ErasedSegment) -> Vec<PathStep> {
682        let midpoint = other;
683
684        for (idx, seg) in enumerate(self.segments()) {
685            let mut steps = vec![PathStep {
686                segment: self.clone(),
687                idx,
688                len: self.segments().len(),
689                code_idxs: self.code_indices(),
690            }];
691
692            if seg.eq(midpoint) {
693                return steps;
694            }
695
696            let res = seg.path_to(midpoint);
697
698            if !res.is_empty() {
699                steps.extend(res);
700                return steps;
701            }
702        }
703
704        Vec::new()
705    }
706
707    pub fn apply_fixes(
708        &self,
709        fixes: &mut FxHashMap<u32, AnchorEditInfo>,
710    ) -> (ErasedSegment, Vec<ErasedSegment>, Vec<ErasedSegment>, bool) {
711        if fixes.is_empty() || self.segments().is_empty() {
712            return (self.clone(), Vec::new(), Vec::new(), true);
713        }
714
715        let mut seg_buffer = Vec::new();
716        let mut has_applied_fixes = false;
717        let mut _requires_validate = false;
718
719        for seg in self.segments() {
720            // Look for uuid match.
721            // This handles potential positioning ambiguity.
722
723            let Some(mut anchor_info) = fixes.remove(&seg.id()) else {
724                seg_buffer.push(seg.clone());
725                continue;
726            };
727
728            if anchor_info.fixes.len() == 2
729                && anchor_info.fixes[0].edit_type == EditType::CreateAfter
730            {
731                anchor_info.fixes.reverse();
732            }
733
734            let fixes_count = anchor_info.fixes.len();
735            for mut lint_fix in anchor_info.fixes {
736                has_applied_fixes = true;
737
738                // Deletes are easy.
739                if lint_fix.edit_type == EditType::Delete {
740                    // We're just getting rid of this segment.
741                    _requires_validate = true;
742                    // NOTE: We don't add the segment in this case.
743                    continue;
744                }
745
746                // Otherwise it must be a replace or a create.
747                assert!(matches!(
748                    lint_fix.edit_type,
749                    EditType::Replace | EditType::CreateBefore | EditType::CreateAfter
750                ));
751
752                if lint_fix.edit_type == EditType::CreateAfter && fixes_count == 1 {
753                    // In the case of a creation after that is not part
754                    // of a create_before/create_after pair, also add
755                    // this segment before the edit.
756                    seg_buffer.push(seg.clone());
757                }
758
759                let mut consumed_pos = false;
760                for mut s in std::mem::take(&mut lint_fix.edit) {
761                    if lint_fix.edit_type == EditType::Replace
762                        && !consumed_pos
763                        && s.raw() == seg.raw()
764                    {
765                        consumed_pos = true;
766                        s.make_mut()
767                            .set_position_marker(seg.get_position_marker().cloned());
768                    }
769
770                    seg_buffer.push(s);
771                }
772
773                if !(lint_fix.edit_type == EditType::Replace
774                    && lint_fix.edit.len() == 1
775                    && lint_fix.edit[0].class_types() == seg.class_types())
776                {
777                    _requires_validate = true;
778                }
779
780                if lint_fix.edit_type == EditType::CreateBefore {
781                    seg_buffer.push(seg.clone());
782                }
783            }
784        }
785
786        if has_applied_fixes {
787            seg_buffer =
788                position_segments(&seg_buffer, self.get_position_marker().as_ref().unwrap());
789        }
790
791        let seg_queue = seg_buffer;
792        let mut seg_buffer = Vec::new();
793        for seg in seg_queue {
794            let (s, pre, post, validated) = seg.apply_fixes(fixes);
795
796            seg_buffer.extend(pre);
797            seg_buffer.push(s);
798            seg_buffer.extend(post);
799
800            if !validated {
801                _requires_validate = true;
802            }
803        }
804
805        let seg_buffer =
806            position_segments(&seg_buffer, self.get_position_marker().as_ref().unwrap());
807        (self.new(seg_buffer), Vec::new(), Vec::new(), false)
808    }
809}
810
/// Serialisable view of a segment tree, available in tests and with the
/// `serde` feature.
#[cfg(any(test, feature = "serde"))]
pub mod serde {
    use serde::ser::SerializeMap;
    use serde::{Deserialize, Serialize};

    use crate::parser::segments::base::ErasedSegment;

    /// Payload of a serialised segment: either its raw text or its children.
    #[derive(Serialize, Deserialize)]
    #[serde(untagged)]
    pub enum SerialisedSegmentValue {
        Single(String),
        Nested(Vec<TupleSerialisedSegment>),
    }

    /// A `(type name, payload)` pair that serialises as a one-entry map.
    #[derive(Deserialize)]
    pub struct TupleSerialisedSegment(String, SerialisedSegmentValue);

    impl Serialize for TupleSerialisedSegment {
        /// Writes the pair as a map with the type name as its only key.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut map = serializer.serialize_map(None)?;
            map.serialize_key(&self.0)?;
            map.serialize_value(&self.1)?;
            map.end()
        }
    }

    impl TupleSerialisedSegment {
        /// Builds a leaf entry mapping `key` to the raw text `value`.
        pub fn single(key: String, value: String) -> Self {
            Self(key, SerialisedSegmentValue::Single(value))
        }

        /// Misspelled alias of [`TupleSerialisedSegment::single`], kept so
        /// that existing callers continue to compile.
        pub fn sinlge(key: String, value: String) -> Self {
            Self::single(key, value)
        }

        /// Builds an inner entry mapping `key` to the child entries
        /// `segments`.
        pub fn nested(key: String, segments: Vec<TupleSerialisedSegment>) -> Self {
            Self(key, SerialisedSegmentValue::Nested(segments))
        }
    }

    impl ErasedSegment {
        /// Converts this subtree into its serialisable form.
        ///
        /// * `code_only` - drop children that are not code or are meta
        ///   segments.
        /// * `show_raw` - emit childless segments as `type: raw text` instead
        ///   of an empty child list.
        pub fn to_serialised(&self, code_only: bool, show_raw: bool) -> TupleSerialisedSegment {
            let key = self.get_type().as_str().to_string();

            if show_raw && self.segments().is_empty() {
                return TupleSerialisedSegment::single(key, self.raw().to_string());
            }

            let children: Vec<TupleSerialisedSegment> = self
                .segments()
                .iter()
                .filter(|seg| !code_only || (seg.is_code() && !seg.is_meta()))
                .map(|seg| seg.to_serialised(code_only, show_raw))
                .collect();

            TupleSerialisedSegment::nested(key, children)
        }
    }
}
878
879impl PartialEq for ErasedSegment {
880    fn eq(&self, other: &Self) -> bool {
881        if self.id() == other.id() {
882            return true;
883        }
884
885        let pos_self = self.get_position_marker();
886        let pos_other = other.get_position_marker();
887        if let Some((pos_self, pos_other)) = pos_self.zip(pos_other) {
888            self.get_type() == other.get_type()
889                && pos_self.working_loc() == pos_other.working_loc()
890                && self.raw() == other.raw()
891        } else {
892            false
893        }
894    }
895}
896
897pub fn position_segments(
898    segments: &[ErasedSegment],
899    parent_pos: &PositionMarker,
900) -> Vec<ErasedSegment> {
901    if segments.is_empty() {
902        return Vec::new();
903    }
904
905    let (mut line_no, mut line_pos) = { (parent_pos.working_line_no, parent_pos.working_line_pos) };
906
907    let mut segment_buffer: Vec<ErasedSegment> = Vec::new();
908    for (idx, segment) in enumerate(segments) {
909        let old_position = segment.get_position_marker();
910
911        let mut new_position = match old_position {
912            Some(pos_marker) => pos_marker.clone(),
913            None => {
914                let start_point = if idx > 0 {
915                    let prev_seg = segment_buffer[idx - 1].clone();
916                    Some(prev_seg.get_position_marker().unwrap().end_point_marker())
917                } else {
918                    Some(parent_pos.start_point_marker())
919                };
920
921                let mut end_point = None;
922                for fwd_seg in &segments[idx + 1..] {
923                    if fwd_seg.get_position_marker().is_some() {
924                        end_point = Some(
925                            fwd_seg.get_raw_segments()[0]
926                                .get_position_marker()
927                                .unwrap()
928                                .start_point_marker(),
929                        );
930                        break;
931                    }
932                }
933
934                if let Some((start_point, end_point)) = start_point
935                    .as_ref()
936                    .zip(end_point.as_ref())
937                    .filter(|(start_point, end_point)| start_point != end_point)
938                {
939                    PositionMarker::from_points(start_point, end_point)
940                } else if let Some(start_point) = start_point.as_ref() {
941                    start_point.clone()
942                } else if let Some(end_point) = end_point.as_ref() {
943                    end_point.clone()
944                } else {
945                    unimplemented!("Unable to position new segment")
946                }
947            }
948        };
949
950        new_position = new_position.with_working_position(line_no, line_pos);
951        (line_no, line_pos) = PositionMarker::infer_next_position(segment.raw(), line_no, line_pos);
952
953        let mut new_seg = if !segment.segments().is_empty() && old_position != Some(&new_position) {
954            let child_segments = position_segments(segment.segments(), &new_position);
955            segment.change_segments(child_segments)
956        } else {
957            segment.deep_clone()
958        };
959
960        new_seg.get_mut().set_position_marker(new_position.into());
961        segment_buffer.push(new_seg);
962    }
963
964    segment_buffer
965}
966
/// Data record for a single segment, which is either a node or a token
/// (see [`NodeOrTokenKind`]).
#[derive(Debug, Clone)]
pub struct NodeOrToken {
    /// Numeric identifier of the segment; can be overwritten via `set_id`.
    id: u32,
    /// The syntax kind this segment was built with.
    syntax_kind: SyntaxKind,
    /// Set of syntax kinds associated with this segment; the builder fills it
    /// from `class_types(syntax_kind)`.
    class_types: SyntaxSet,
    /// Position of the segment, if one has been assigned; can be overwritten
    /// via `set_position_marker`.
    position_marker: Option<PositionMarker>,
    /// Node- or token-specific payload.
    kind: NodeOrTokenKind,
    /// Write-once cell holding a shared list of indices.
    /// NOTE(review): presumably the indices of code children — confirm
    /// against the code that fills it.
    code_idx: OnceCell<Rc<Vec<usize>>>,
    /// Write-once cell holding a `u64`.
    /// NOTE(review): presumably a memoised hash of the segment — confirm
    /// against the `Hash` implementation.
    hash: OnceCell<u64>,
}
977
/// Payload distinguishing the two flavours of [`NodeOrToken`].
#[derive(Debug, Clone)]
pub enum NodeOrTokenKind {
    /// A node: carries child segments and related data (see [`NodeData`]).
    Node(NodeData),
    /// A token: carries only its raw text (see [`TokenData`]).
    Token(TokenData),
}
983
impl NodeOrToken {
    /// Overwrites the stored position marker; passing `None` clears it.
    ///
    /// NOTE(review): only `position_marker` is assigned here; the `code_idx`
    /// and `hash` cells are left untouched — confirm they do not depend on
    /// the position.
    pub fn set_position_marker(&mut self, position_marker: Option<PositionMarker>) {
        self.position_marker = position_marker;
    }

    /// Overwrites the segment's numeric identifier.
    pub fn set_id(&mut self, id: u32) {
        self.id = id;
    }
}
993
/// Payload of a node segment, i.e. a segment that owns child segments.
#[derive(Debug, Clone)]
pub struct NodeData {
    /// Dialect the node was built for.
    dialect: DialectKind,
    /// Child segments of this node.
    segments: Vec<ErasedSegment>,
    /// Write-once cell for the node's raw text.
    /// NOTE(review): presumably the concatenated raw text of the children,
    /// computed on first access — confirm.
    raw: OnceCell<SmolStr>,
    /// Source fixes attached to this node.
    source_fixes: Vec<SourceFix>,
    /// Write-once cell for a set of syntax kinds.
    /// NOTE(review): presumably the kinds of all descendants — confirm.
    descendant_type_set: OnceCell<SyntaxSet>,
    /// Write-once cell pairing segments with a list of [`PathStep`]s.
    /// NOTE(review): presumably each raw segment together with the path of
    /// ancestors leading to it — confirm.
    raw_segments_with_ancestors: OnceCell<Vec<(ErasedSegment, Vec<PathStep>)>>,
}
1003
/// Payload of a token segment: just its raw text.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenData {
    /// Raw text of the token.
    raw: SmolStr,
}
1008
1009#[track_caller]
1010pub fn pos_marker(segments: &[ErasedSegment]) -> PositionMarker {
1011    let markers = segments.iter().filter_map(|seg| seg.get_position_marker());
1012
1013    PositionMarker::from_child_markers(markers)
1014}
1015
/// One entry in a list of steps associated with a segment (see
/// `NodeData::raw_segments_with_ancestors`).
///
/// NOTE(review): presumably describes one ancestor on the way down to a
/// descendant segment — confirm against the code that builds these.
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Segment this step refers to.
    pub segment: ErasedSegment,
    /// An index. NOTE(review): presumably the position of the next step's
    /// segment among `segment`'s children — confirm.
    pub idx: usize,
    /// A length. NOTE(review): presumably the number of children of
    /// `segment` — confirm.
    pub len: usize,
    /// Shared list of indices. NOTE(review): presumably the indices of
    /// `segment`'s code children — confirm.
    pub code_idxs: Rc<Vec<usize>>,
}
1023
1024fn class_types(syntax_kind: SyntaxKind) -> SyntaxSet {
1025    match syntax_kind {
1026        SyntaxKind::ColumnReference => SyntaxSet::new(&[SyntaxKind::ObjectReference, syntax_kind]),
1027        SyntaxKind::WildcardIdentifier => {
1028            SyntaxSet::new(&[SyntaxKind::WildcardIdentifier, SyntaxKind::ObjectReference])
1029        }
1030        SyntaxKind::TableReference => SyntaxSet::new(&[SyntaxKind::ObjectReference, syntax_kind]),
1031        _ => SyntaxSet::single(syntax_kind),
1032    }
1033}
1034
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lint_fix::LintFix;
    use crate::linter::compute_anchor_edit_info;
    use crate::parser::segments::test_functions::{raw_seg, raw_segments};

    #[test]
    /// Two raw segments built from identical inputs compare equal.
    fn test_parser_base_segments_raw_compare() {
        let template: TemplatedFile = "foobar".into();

        // Both sides are produced by the same recipe.
        let build_segment = || {
            SegmentBuilder::token(0, "foobar", SyntaxKind::Word)
                .with_position(PositionMarker::new(
                    0..6,
                    0..6,
                    template.clone(),
                    None,
                    None,
                ))
                .finish()
        };

        let first = build_segment();
        let second = build_segment();

        assert_eq!(first, second)
    }

    #[test]
    // TODO Implement
    /// The shared raw-segment fixture exposes its raw text.
    fn test_parser_base_segments_raw() {
        let segment = raw_seg();

        assert_eq!(segment.raw(), "foobar");
    }

    #[test]
    /// Exercises `compute_anchor_edit_info()` on duplicate and conflicting fixes.
    fn test_parser_base_segments_compute_anchor_edit_info() {
        let raw_segs = raw_segments();
        let tables = Tables::default();
        let target = &raw_segs[0];

        // Builds a replace-fix anchored on the first raw segment whose single
        // edit carries the given raw text and a freshly allocated id.
        let replace_with = |raw: &str| {
            LintFix::replace(
                target.clone(),
                vec![target.edit(tables.next_id(), Some(raw.to_string()), None)],
                None,
            )
        };

        // The fix buffer deliberately contains:
        // - one duplicate ("a" twice);
        // - two incompatible fixes on the same segment ("a" vs "b").
        let fixes = vec![replace_with("a"), replace_with("a"), replace_with("b")];

        let anchor_edit_info = compute_anchor_edit_info(fixes.into_iter());

        // The target segment must be the only key present.
        assert_eq!(
            anchor_edit_info.keys().collect::<Vec<_>>(),
            vec![&target.id()]
        );

        let anchor_info = anchor_edit_info.get(&target.id()).unwrap();

        // The duplicate has been deduplicated, i.e. this is 2 rather than 3.
        assert_eq!(anchor_info.replace, 2);

        // The surviving fixes, in order; the first one is not duplicated.
        assert_eq!(anchor_info.fixes[0], replace_with("a"));
        assert_eq!(anchor_info.fixes[1], replace_with("b"));

        // The recorded first replace points at the "a" fix.
        assert_eq!(
            anchor_info.fixes[anchor_info.first_replace.unwrap()],
            replace_with("a")
        );
    }
}