sqruff_lib_core/parser/
segments.rs

1pub mod bracketed;
2pub mod file;
3pub mod fix;
4pub mod from;
5pub mod generator;
6pub mod join;
7pub mod meta;
8pub mod object_reference;
9pub mod select;
10pub mod test_functions;
11
12use std::cell::{Cell, OnceCell};
13use std::fmt::Debug;
14use std::hash::{Hash, Hasher};
15use std::rc::Rc;
16
17use itertools::enumerate;
18use rustc_hash::FxHashMap;
19use smol_str::SmolStr;
20
21use crate::dialects::init::DialectKind;
22use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
23use crate::edit_type::EditType;
24use crate::parser::markers::PositionMarker;
25use crate::parser::segments::fix::{FixPatch, SourceFix};
26use crate::parser::segments::object_reference::{ObjectReferenceKind, ObjectReferenceSegment};
27use crate::segments::AnchorEditInfo;
28use crate::templaters::TemplatedFile;
29
30pub struct SegmentBuilder {
31    node_or_token: NodeOrToken,
32}
33
34impl SegmentBuilder {
35    pub fn whitespace(id: u32, raw: &str) -> ErasedSegment {
36        SegmentBuilder::token(id, raw, SyntaxKind::Whitespace).finish()
37    }
38
39    pub fn newline(id: u32, raw: &str) -> ErasedSegment {
40        SegmentBuilder::token(id, raw, SyntaxKind::Newline).finish()
41    }
42
43    pub fn keyword(id: u32, raw: &str) -> ErasedSegment {
44        SegmentBuilder::token(id, raw, SyntaxKind::Keyword).finish()
45    }
46
47    pub fn comma(id: u32) -> ErasedSegment {
48        SegmentBuilder::token(id, ",", SyntaxKind::Comma).finish()
49    }
50
51    pub fn symbol(id: u32, raw: &str) -> ErasedSegment {
52        SegmentBuilder::token(id, raw, SyntaxKind::Symbol).finish()
53    }
54
55    pub fn node(
56        id: u32,
57        syntax_kind: SyntaxKind,
58        dialect: DialectKind,
59        segments: Vec<ErasedSegment>,
60    ) -> Self {
61        SegmentBuilder {
62            node_or_token: NodeOrToken {
63                id,
64                syntax_kind,
65                class_types: class_types(syntax_kind),
66                position_marker: None,
67                code_idx: OnceCell::new(),
68                kind: NodeOrTokenKind::Node(NodeData {
69                    dialect,
70                    segments,
71                    raw: Default::default(),
72                    source_fixes: vec![],
73                    descendant_type_set: Default::default(),
74                    raw_segments_with_ancestors: Default::default(),
75                }),
76                hash: OnceCell::new(),
77            },
78        }
79    }
80
81    pub fn token(id: u32, raw: &str, syntax_kind: SyntaxKind) -> Self {
82        SegmentBuilder {
83            node_or_token: NodeOrToken {
84                id,
85                syntax_kind,
86                code_idx: OnceCell::new(),
87                class_types: class_types(syntax_kind),
88                position_marker: None,
89                kind: NodeOrTokenKind::Token(TokenData { raw: raw.into() }),
90                hash: OnceCell::new(),
91            },
92        }
93    }
94
95    pub fn position_from_segments(mut self) -> Self {
96        let segments = match &self.node_or_token.kind {
97            NodeOrTokenKind::Node(node) => &node.segments[..],
98            NodeOrTokenKind::Token(_) => &[],
99        };
100
101        self.node_or_token.position_marker = pos_marker(segments).into();
102        self
103    }
104
105    pub fn with_position(mut self, position: PositionMarker) -> Self {
106        self.node_or_token.position_marker = Some(position);
107        self
108    }
109
110    pub fn finish(self) -> ErasedSegment {
111        ErasedSegment {
112            value: Rc::new(self.node_or_token),
113        }
114    }
115}
116
/// Hands out sequential `u32` identifiers, starting at zero.
#[derive(Debug, Default)]
pub struct Tables {
    counter: Cell<u32>,
}

impl Tables {
    /// Returns the current counter value and advances the counter by one.
    pub fn next_id(&self) -> u32 {
        let current = self.counter.get();
        // `Cell::replace` stores the successor and hands back the value it replaced.
        self.counter.replace(current + 1)
    }
}
129
130#[derive(Debug, Clone)]
131pub struct ErasedSegment {
132    pub(crate) value: Rc<NodeOrToken>,
133}
134
135impl Hash for ErasedSegment {
136    fn hash<H: Hasher>(&self, state: &mut H) {
137        self.hash_value().hash(state);
138    }
139}
140
141impl Eq for ErasedSegment {}
142
143impl ErasedSegment {
144    pub fn raw(&self) -> &SmolStr {
145        match &self.value.kind {
146            NodeOrTokenKind::Node(node) => node.raw.get_or_init(|| {
147                SmolStr::from_iter(self.segments().iter().map(|segment| segment.raw().as_str()))
148            }),
149            NodeOrTokenKind::Token(token) => &token.raw,
150        }
151    }
152
153    pub fn segments(&self) -> &[ErasedSegment] {
154        match &self.value.kind {
155            NodeOrTokenKind::Node(node) => &node.segments,
156            NodeOrTokenKind::Token(_) => &[],
157        }
158    }
159
160    pub fn get_type(&self) -> SyntaxKind {
161        self.value.syntax_kind
162    }
163
164    pub fn is_type(&self, kind: SyntaxKind) -> bool {
165        self.get_type() == kind
166    }
167
168    pub fn is_meta(&self) -> bool {
169        matches!(
170            self.value.syntax_kind,
171            SyntaxKind::Indent | SyntaxKind::Implicit | SyntaxKind::Dedent | SyntaxKind::EndOfFile
172        )
173    }
174
175    pub fn is_code(&self) -> bool {
176        match &self.value.kind {
177            NodeOrTokenKind::Node(node) => node.segments.iter().any(|s| s.is_code()),
178            NodeOrTokenKind::Token(_) => {
179                !self.is_comment() && !self.is_whitespace() && !self.is_meta()
180            }
181        }
182    }
183
184    pub fn get_raw_segments(&self) -> Vec<ErasedSegment> {
185        self.recursive_crawl_all(false)
186            .into_iter()
187            .filter(|it| it.segments().is_empty())
188            .collect()
189    }
190
191    #[cfg(feature = "stringify")]
192    pub fn stringify(&self, code_only: bool) -> String {
193        serde_yaml::to_string(&self.to_serialised(code_only, true)).unwrap()
194    }
195
196    pub fn child(&self, seg_types: &SyntaxSet) -> Option<ErasedSegment> {
197        self.segments()
198            .iter()
199            .find(|seg| seg_types.contains(seg.get_type()))
200            .cloned()
201    }
202
203    pub fn recursive_crawl(
204        &self,
205        types: &SyntaxSet,
206        recurse_into: bool,
207        no_recursive_types: &SyntaxSet,
208        allow_self: bool,
209    ) -> Vec<ErasedSegment> {
210        let mut acc = Vec::new();
211
212        let matches = allow_self && self.class_types().intersects(types);
213        if matches {
214            acc.push(self.clone());
215        }
216
217        if !self.descendant_type_set().intersects(types) {
218            return acc;
219        }
220
221        if recurse_into || !matches {
222            for seg in self.segments() {
223                if no_recursive_types.is_empty() || !no_recursive_types.contains(seg.get_type()) {
224                    let segments =
225                        seg.recursive_crawl(types, recurse_into, no_recursive_types, true);
226                    acc.extend(segments);
227                }
228            }
229        }
230
231        acc
232    }
233}
234
235impl ErasedSegment {
236    #[allow(clippy::new_ret_no_self, clippy::wrong_self_convention)]
237    #[track_caller]
238    pub fn new(&self, segments: Vec<ErasedSegment>) -> ErasedSegment {
239        match &self.value.kind {
240            NodeOrTokenKind::Node(node) => SegmentBuilder::node(
241                self.value.id,
242                self.value.syntax_kind,
243                node.dialect,
244                segments,
245            )
246            .with_position(self.get_position_marker().unwrap().clone())
247            .finish(),
248            NodeOrTokenKind::Token(_) => self.deep_clone(),
249        }
250    }
251
252    fn change_segments(&self, segments: Vec<ErasedSegment>) -> ErasedSegment {
253        let NodeOrTokenKind::Node(node) = &self.value.kind else {
254            unimplemented!()
255        };
256
257        ErasedSegment {
258            value: Rc::new(NodeOrToken {
259                id: self.value.id,
260                syntax_kind: self.value.syntax_kind,
261                class_types: self.value.class_types.clone(),
262                position_marker: None,
263                code_idx: OnceCell::new(),
264                kind: NodeOrTokenKind::Node(NodeData {
265                    dialect: node.dialect,
266                    segments,
267                    raw: node.raw.clone(),
268                    source_fixes: node.source_fixes.clone(),
269                    descendant_type_set: node.descendant_type_set.clone(),
270                    raw_segments_with_ancestors: node.raw_segments_with_ancestors.clone(),
271                }),
272                hash: OnceCell::new(),
273            }),
274        }
275    }
276
277    pub fn indent_val(&self) -> i8 {
278        self.value.syntax_kind.indent_val()
279    }
280
281    pub fn can_start_end_non_code(&self) -> bool {
282        matches!(
283            self.value.syntax_kind,
284            SyntaxKind::File | SyntaxKind::Unparsable
285        )
286    }
287
288    pub(crate) fn dialect(&self) -> DialectKind {
289        match &self.value.kind {
290            NodeOrTokenKind::Node(node) => node.dialect,
291            NodeOrTokenKind::Token(_) => todo!(),
292        }
293    }
294
295    pub fn get_start_loc(&self) -> (usize, usize) {
296        match self.get_position_marker() {
297            Some(pos_marker) => pos_marker.working_loc(),
298            None => unreachable!("{self:?} has no PositionMarker"),
299        }
300    }
301
302    pub fn get_end_loc(&self) -> (usize, usize) {
303        match self.get_position_marker() {
304            Some(pos_marker) => pos_marker.working_loc_after(self.raw()),
305            None => {
306                unreachable!("{self:?} has no PositionMarker")
307            }
308        }
309    }
310
311    pub fn is_templated(&self) -> bool {
312        if let Some(pos_marker) = self.get_position_marker() {
313            pos_marker.source_slice.start != pos_marker.source_slice.end && !pos_marker.is_literal()
314        } else {
315            panic!("PosMarker must be set");
316        }
317    }
318
319    pub fn iter_segments(&self, expanding: &SyntaxSet, pass_through: bool) -> Vec<ErasedSegment> {
320        let capacity = if expanding.is_empty() {
321            self.segments().len()
322        } else {
323            0
324        };
325        let mut result = Vec::with_capacity(capacity);
326        for segment in self.segments() {
327            if expanding.contains(segment.get_type()) {
328                let expanding = if pass_through {
329                    expanding
330                } else {
331                    &SyntaxSet::EMPTY
332                };
333                result.append(&mut segment.iter_segments(expanding, false));
334            } else {
335                result.push(segment.clone());
336            }
337        }
338        result
339    }
340
341    pub(crate) fn code_indices(&self) -> Rc<Vec<usize>> {
342        self.value
343            .code_idx
344            .get_or_init(|| {
345                Rc::from(
346                    self.segments()
347                        .iter()
348                        .enumerate()
349                        .filter(|(_, seg)| seg.is_code())
350                        .map(|(idx, _)| idx)
351                        .collect::<Vec<_>>(),
352                )
353            })
354            .clone()
355    }
356
357    pub fn children(
358        &self,
359        seg_types: &'static SyntaxSet,
360    ) -> impl Iterator<Item = &ErasedSegment> + '_ {
361        self.segments()
362            .iter()
363            .filter(move |seg| seg_types.contains(seg.get_type()))
364    }
365
366    pub fn iter_patches(&self, templated_file: &TemplatedFile) -> Vec<FixPatch> {
367        let mut acc = Vec::new();
368
369        let templated_raw = &templated_file.templated_str.as_ref().unwrap()
370            [self.get_position_marker().unwrap().templated_slice.clone()];
371        if self.raw() == templated_raw {
372            acc.extend(self.iter_source_fix_patches(templated_file));
373            return acc;
374        }
375
376        if self.get_position_marker().is_none() {
377            return Vec::new();
378        }
379
380        let pos_marker = self.get_position_marker().unwrap();
381        if pos_marker.is_literal() {
382            acc.extend(self.iter_source_fix_patches(templated_file));
383            acc.push(FixPatch::new(
384                pos_marker.templated_slice.clone(),
385                self.raw().clone(),
386                // SyntaxKind::Literal.into(),
387                pos_marker.source_slice.clone(),
388                templated_file.templated_str.as_ref().unwrap()[pos_marker.templated_slice.clone()]
389                    .to_string(),
390                templated_file.source_str[pos_marker.source_slice.clone()].to_string(),
391            ));
392        } else if self.segments().is_empty() {
393            return acc;
394        } else {
395            let mut segments = self.segments();
396
397            while !segments.is_empty()
398                && matches!(
399                    segments.last().unwrap().get_type(),
400                    SyntaxKind::EndOfFile
401                        | SyntaxKind::Indent
402                        | SyntaxKind::Dedent
403                        | SyntaxKind::Implicit
404                )
405            {
406                segments = &segments[..segments.len() - 1];
407            }
408
409            let pos = self.get_position_marker().unwrap();
410            let mut source_idx = pos.source_slice.start;
411            let mut templated_idx = pos.templated_slice.start;
412            let mut insert_buff = String::new();
413
414            for segment in segments {
415                let pos_marker = segment.get_position_marker().unwrap();
416                if !segment.raw().is_empty() && pos_marker.is_point() {
417                    insert_buff.push_str(segment.raw().as_ref());
418                    continue;
419                }
420
421                let start_diff = pos_marker.templated_slice.start - templated_idx;
422
423                if start_diff > 0 || !insert_buff.is_empty() {
424                    let fixed_raw = std::mem::take(&mut insert_buff);
425                    let raw_segments = segment.get_raw_segments();
426                    let first_segment_pos = raw_segments[0].get_position_marker().unwrap();
427
428                    acc.push(FixPatch::new(
429                        templated_idx..first_segment_pos.templated_slice.start,
430                        fixed_raw.into(),
431                        source_idx..first_segment_pos.source_slice.start,
432                        String::new(),
433                        String::new(),
434                    ));
435                }
436
437                acc.extend(segment.iter_patches(templated_file));
438
439                source_idx = pos_marker.source_slice.end;
440                templated_idx = pos_marker.templated_slice.end;
441            }
442
443            let end_diff = pos.templated_slice.end - templated_idx;
444            if end_diff != 0 || !insert_buff.is_empty() {
445                let source_slice = source_idx..pos.source_slice.end;
446                let templated_slice = templated_idx..pos.templated_slice.end;
447
448                let templated_str = templated_file.templated_str.as_ref().unwrap()
449                    [templated_slice.clone()]
450                .to_owned();
451                let source_str = templated_file.source_str[source_slice.clone()].to_owned();
452
453                acc.push(FixPatch::new(
454                    templated_slice,
455                    insert_buff.into(),
456                    source_slice,
457                    templated_str,
458                    source_str,
459                ));
460            }
461        }
462
463        acc
464    }
465
466    pub fn descendant_type_set(&self) -> &SyntaxSet {
467        match &self.value.kind {
468            NodeOrTokenKind::Node(node) => node.descendant_type_set.get_or_init(|| {
469                self.segments()
470                    .iter()
471                    .flat_map(|segment| {
472                        segment
473                            .descendant_type_set()
474                            .clone()
475                            .union(segment.class_types())
476                    })
477                    .collect()
478            }),
479            NodeOrTokenKind::Token(_) => const { &SyntaxSet::EMPTY },
480        }
481    }
482
483    pub fn is_comment(&self) -> bool {
484        matches!(
485            self.value.syntax_kind,
486            SyntaxKind::Comment | SyntaxKind::InlineComment | SyntaxKind::BlockComment
487        )
488    }
489
490    pub fn is_whitespace(&self) -> bool {
491        matches!(
492            self.value.syntax_kind,
493            SyntaxKind::Whitespace | SyntaxKind::Newline
494        )
495    }
496
497    pub fn is_indent(&self) -> bool {
498        matches!(
499            self.value.syntax_kind,
500            SyntaxKind::Indent | SyntaxKind::Implicit | SyntaxKind::Dedent
501        )
502    }
503
504    pub fn get_position_marker(&self) -> Option<&PositionMarker> {
505        self.value.position_marker.as_ref()
506    }
507
508    pub(crate) fn iter_source_fix_patches(&self, templated_file: &TemplatedFile) -> Vec<FixPatch> {
509        let source_fixes = self.get_source_fixes();
510        let mut patches = Vec::with_capacity(source_fixes.len());
511
512        for source_fix in &source_fixes {
513            patches.push(FixPatch::new(
514                source_fix.templated_slice.clone(),
515                source_fix.edit.clone(),
516                // String::from("source"),
517                source_fix.source_slice.clone(),
518                templated_file.templated_str.clone().unwrap()[source_fix.templated_slice.clone()]
519                    .to_string(),
520                templated_file.source_str[source_fix.source_slice.clone()].to_string(),
521            ));
522        }
523
524        patches
525    }
526
527    pub fn id(&self) -> u32 {
528        self.value.id
529    }
530
531    /// Return any source fixes as list.
532    pub fn get_source_fixes(&self) -> Vec<SourceFix> {
533        match &self.value.kind {
534            NodeOrTokenKind::Node(node) => node.source_fixes.clone(),
535            NodeOrTokenKind::Token(_) => Vec::new(),
536        }
537    }
538
539    pub fn edit(
540        &self,
541        id: u32,
542        raw: Option<String>,
543        _source_fixes: Option<Vec<SourceFix>>,
544    ) -> ErasedSegment {
545        match &self.value.kind {
546            NodeOrTokenKind::Node(_node) => {
547                todo!()
548            }
549            NodeOrTokenKind::Token(token) => {
550                let raw = raw.as_deref().unwrap_or(token.raw.as_ref());
551                SegmentBuilder::token(id, raw, self.value.syntax_kind)
552                    .with_position(self.get_position_marker().unwrap().clone())
553                    .finish()
554            }
555        }
556    }
557
558    pub fn class_types(&self) -> &SyntaxSet {
559        &self.value.class_types
560    }
561
562    pub(crate) fn first_non_whitespace_segment_raw_upper(&self) -> Option<String> {
563        for seg in self.get_raw_segments() {
564            if !seg.raw().is_empty() {
565                return Some(seg.raw().to_uppercase());
566            }
567        }
568        None
569    }
570
571    pub fn is(&self, other: &ErasedSegment) -> bool {
572        Rc::ptr_eq(&self.value, &other.value)
573    }
574
575    pub fn addr(&self) -> usize {
576        Rc::as_ptr(&self.value).addr()
577    }
578
579    pub fn direct_descendant_type_set(&self) -> SyntaxSet {
580        self.segments()
581            .iter()
582            .fold(SyntaxSet::EMPTY, |set, it| set.union(it.class_types()))
583    }
584
585    pub fn is_keyword(&self, p0: &str) -> bool {
586        self.is_type(SyntaxKind::Keyword) && self.raw().eq_ignore_ascii_case(p0)
587    }
588
589    pub fn hash_value(&self) -> u64 {
590        *self.value.hash.get_or_init(|| {
591            let mut hasher = ahash::AHasher::default();
592            self.get_type().hash(&mut hasher);
593            self.raw().hash(&mut hasher);
594
595            if let Some(marker) = &self.get_position_marker() {
596                marker.source_position().hash(&mut hasher);
597            } else {
598                None::<usize>.hash(&mut hasher);
599            }
600
601            hasher.finish()
602        })
603    }
604
605    pub fn deep_clone(&self) -> Self {
606        Self {
607            value: Rc::new(self.value.as_ref().clone()),
608        }
609    }
610
611    #[track_caller]
612    pub(crate) fn get_mut(&mut self) -> &mut NodeOrToken {
613        Rc::get_mut(&mut self.value).unwrap()
614    }
615
616    #[track_caller]
617    pub(crate) fn make_mut(&mut self) -> &mut NodeOrToken {
618        Rc::make_mut(&mut self.value)
619    }
620
621    pub fn reference(&self) -> ObjectReferenceSegment {
622        ObjectReferenceSegment(
623            self.clone(),
624            match self.get_type() {
625                SyntaxKind::TableReference => ObjectReferenceKind::Table,
626                SyntaxKind::WildcardIdentifier => ObjectReferenceKind::WildcardIdentifier,
627                _ => ObjectReferenceKind::Object,
628            },
629        )
630    }
631
632    pub fn recursive_crawl_all(&self, reverse: bool) -> Vec<ErasedSegment> {
633        let mut result = Vec::with_capacity(self.segments().len() + 1);
634
635        if reverse {
636            for seg in self.segments().iter().rev() {
637                result.append(&mut seg.recursive_crawl_all(reverse));
638            }
639            result.push(self.clone());
640        } else {
641            result.push(self.clone());
642            for seg in self.segments() {
643                result.append(&mut seg.recursive_crawl_all(reverse));
644            }
645        }
646
647        result
648    }
649
650    pub fn raw_segments_with_ancestors(&self) -> &[(ErasedSegment, Vec<PathStep>)] {
651        match &self.value.kind {
652            NodeOrTokenKind::Node(node) => node.raw_segments_with_ancestors.get_or_init(|| {
653                let mut buffer: Vec<(ErasedSegment, Vec<PathStep>)> =
654                    Vec::with_capacity(self.segments().len());
655                let code_idxs = self.code_indices();
656
657                for (idx, seg) in self.segments().iter().enumerate() {
658                    let new_step = vec![PathStep {
659                        segment: self.clone(),
660                        idx,
661                        len: self.segments().len(),
662                        code_idxs: code_idxs.clone(),
663                    }];
664
665                    // Use seg.get_segments().is_empty() as a workaround to check if the segment is
666                    // a SyntaxKind::Raw type. In the original Python code, this was achieved
667                    // using seg.is_type(SyntaxKind::Raw). Here, we assume that a SyntaxKind::Raw
668                    // segment is characterized by having no sub-segments.
669
670                    if seg.segments().is_empty() {
671                        buffer.push((seg.clone(), new_step));
672                    } else {
673                        let extended =
674                            seg.raw_segments_with_ancestors()
675                                .iter()
676                                .map(|(raw_seg, stack)| {
677                                    let mut new_step = new_step.clone();
678                                    new_step.extend_from_slice(stack);
679                                    (raw_seg.clone(), new_step)
680                                });
681
682                        buffer.extend(extended);
683                    }
684                }
685
686                buffer
687            }),
688            NodeOrTokenKind::Token(_) => &[],
689        }
690    }
691
692    pub fn path_to(&self, other: &ErasedSegment) -> Vec<PathStep> {
693        let midpoint = other;
694
695        for (idx, seg) in enumerate(self.segments()) {
696            let mut steps = vec![PathStep {
697                segment: self.clone(),
698                idx,
699                len: self.segments().len(),
700                code_idxs: self.code_indices(),
701            }];
702
703            if seg.eq(midpoint) {
704                return steps;
705            }
706
707            let res = seg.path_to(midpoint);
708
709            if !res.is_empty() {
710                steps.extend(res);
711                return steps;
712            }
713        }
714
715        Vec::new()
716    }
717
718    pub fn apply_fixes(
719        &self,
720        fixes: &mut FxHashMap<u32, AnchorEditInfo>,
721    ) -> (ErasedSegment, Vec<ErasedSegment>, Vec<ErasedSegment>) {
722        if fixes.is_empty() || self.segments().is_empty() {
723            return (self.clone(), Vec::new(), Vec::new());
724        }
725
726        let mut seg_buffer = Vec::new();
727        let mut has_applied_fixes = false;
728        let mut _requires_validate = false;
729
730        for seg in self.segments() {
731            // Look for uuid match.
732            // This handles potential positioning ambiguity.
733
734            let Some(mut anchor_info) = fixes.remove(&seg.id()) else {
735                seg_buffer.push(seg.clone());
736                continue;
737            };
738
739            if anchor_info.fixes.len() == 2
740                && anchor_info.fixes[0].edit_type == EditType::CreateAfter
741            {
742                anchor_info.fixes.reverse();
743            }
744
745            let fixes_count = anchor_info.fixes.len();
746            for mut lint_fix in anchor_info.fixes {
747                has_applied_fixes = true;
748
749                // Deletes are easy.
750                if lint_fix.edit_type == EditType::Delete {
751                    // We're just getting rid of this segment.
752                    _requires_validate = true;
753                    // NOTE: We don't add the segment in this case.
754                    continue;
755                }
756
757                // Otherwise it must be a replace or a create.
758                assert!(matches!(
759                    lint_fix.edit_type,
760                    EditType::Replace | EditType::CreateBefore | EditType::CreateAfter
761                ));
762
763                if lint_fix.edit_type == EditType::CreateAfter && fixes_count == 1 {
764                    // In the case of a creation after that is not part
765                    // of a create_before/create_after pair, also add
766                    // this segment before the edit.
767                    seg_buffer.push(seg.clone());
768                }
769
770                let mut consumed_pos = false;
771                for mut s in std::mem::take(&mut lint_fix.edit) {
772                    if lint_fix.edit_type == EditType::Replace
773                        && !consumed_pos
774                        && s.raw() == seg.raw()
775                    {
776                        consumed_pos = true;
777                        s.make_mut()
778                            .set_position_marker(seg.get_position_marker().cloned());
779                    }
780
781                    seg_buffer.push(s);
782                }
783
784                if !(lint_fix.edit_type == EditType::Replace
785                    && lint_fix.edit.len() == 1
786                    && lint_fix.edit[0].class_types() == seg.class_types())
787                {
788                    _requires_validate = true;
789                }
790
791                if lint_fix.edit_type == EditType::CreateBefore {
792                    seg_buffer.push(seg.clone());
793                }
794            }
795        }
796
797        if has_applied_fixes {
798            seg_buffer =
799                position_segments(&seg_buffer, self.get_position_marker().as_ref().unwrap());
800        }
801
802        let seg_queue = seg_buffer;
803        let mut seg_buffer = Vec::new();
804        for seg in seg_queue {
805            let (mid, pre, post) = seg.apply_fixes(fixes);
806
807            seg_buffer.extend(pre);
808            seg_buffer.push(mid);
809            seg_buffer.extend(post);
810        }
811
812        let seg_buffer =
813            position_segments(&seg_buffer, self.get_position_marker().as_ref().unwrap());
814        (self.new(seg_buffer), Vec::new(), Vec::new())
815    }
816}
817
/// Serialisation helpers that turn a segment tree into nested single-entry maps.
#[cfg(any(test, feature = "serde"))]
pub mod serde {
    use serde::ser::SerializeMap;
    use serde::{Deserialize, Serialize};

    use crate::parser::segments::ErasedSegment;

    /// Value side of a serialised segment: either raw text or a list of child segments.
    #[derive(Serialize, Deserialize)]
    #[serde(untagged)]
    pub enum SerialisedSegmentValue {
        Single(String),
        Nested(Vec<TupleSerialisedSegment>),
    }

    /// A `(key, value)` pair that serialises as a map with exactly one entry.
    #[derive(Deserialize)]
    pub struct TupleSerialisedSegment(String, SerialisedSegmentValue);

    impl Serialize for TupleSerialisedSegment {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut map = serializer.serialize_map(None)?;
            map.serialize_key(&self.0)?;
            map.serialize_value(&self.1)?;
            map.end()
        }
    }

    impl TupleSerialisedSegment {
        /// Creates an entry whose value is a single string.
        pub fn single(key: String, value: String) -> Self {
            Self(key, SerialisedSegmentValue::Single(value))
        }

        /// Misspelled alias of [`TupleSerialisedSegment::single`], kept so existing callers
        /// continue to compile.
        pub fn sinlge(key: String, value: String) -> Self {
            Self::single(key, value)
        }

        /// Creates an entry whose value is a list of nested entries.
        pub fn nested(key: String, segments: Vec<TupleSerialisedSegment>) -> Self {
            Self(key, SerialisedSegmentValue::Nested(segments))
        }
    }

    impl ErasedSegment {
        /// Serialises this segment and its subtree, keyed by syntax kind.
        ///
        /// * `code_only` - drop children that are not code or that are meta segments.
        /// * `show_raw` - serialise a childless segment as its raw text instead of an empty
        ///   list.
        pub fn to_serialised(&self, code_only: bool, show_raw: bool) -> TupleSerialisedSegment {
            let key = self.get_type().as_str().to_string();

            if show_raw && self.segments().is_empty() {
                return TupleSerialisedSegment::single(key, self.raw().to_string());
            }

            let children = self
                .segments()
                .iter()
                .filter(|seg| !code_only || (seg.is_code() && !seg.is_meta()))
                .map(|seg| seg.to_serialised(code_only, show_raw))
                .collect::<Vec<_>>();

            TupleSerialisedSegment::nested(key, children)
        }
    }
}
885
886impl PartialEq for ErasedSegment {
887    fn eq(&self, other: &Self) -> bool {
888        if self.id() == other.id() {
889            return true;
890        }
891
892        let pos_self = self.get_position_marker();
893        let pos_other = other.get_position_marker();
894        if let Some((pos_self, pos_other)) = pos_self.zip(pos_other) {
895            self.get_type() == other.get_type()
896                && pos_self.working_loc() == pos_other.working_loc()
897                && self.raw() == other.raw()
898        } else {
899            false
900        }
901    }
902}
903
904pub fn position_segments(
905    segments: &[ErasedSegment],
906    parent_pos: &PositionMarker,
907) -> Vec<ErasedSegment> {
908    if segments.is_empty() {
909        return Vec::new();
910    }
911
912    let (mut line_no, mut line_pos) = { (parent_pos.working_line_no, parent_pos.working_line_pos) };
913
914    let mut segment_buffer: Vec<ErasedSegment> = Vec::new();
915    for (idx, segment) in enumerate(segments) {
916        let old_position = segment.get_position_marker();
917
918        let mut new_position = match old_position {
919            Some(pos_marker) => pos_marker.clone(),
920            None => {
921                let start_point = if idx > 0 {
922                    let prev_seg = segment_buffer[idx - 1].clone();
923                    Some(prev_seg.get_position_marker().unwrap().end_point_marker())
924                } else {
925                    Some(parent_pos.start_point_marker())
926                };
927
928                let mut end_point = None;
929                for fwd_seg in &segments[idx + 1..] {
930                    if fwd_seg.get_position_marker().is_some() {
931                        end_point = Some(
932                            fwd_seg.get_raw_segments()[0]
933                                .get_position_marker()
934                                .unwrap()
935                                .start_point_marker(),
936                        );
937                        break;
938                    }
939                }
940
941                if let Some((start_point, end_point)) = start_point
942                    .as_ref()
943                    .zip(end_point.as_ref())
944                    .filter(|(start_point, end_point)| start_point != end_point)
945                {
946                    PositionMarker::from_points(start_point, end_point)
947                } else if let Some(start_point) = start_point.as_ref() {
948                    start_point.clone()
949                } else if let Some(end_point) = end_point.as_ref() {
950                    end_point.clone()
951                } else {
952                    unimplemented!("Unable to position new segment")
953                }
954            }
955        };
956
957        new_position = new_position.with_working_position(line_no, line_pos);
958        (line_no, line_pos) = PositionMarker::infer_next_position(segment.raw(), line_no, line_pos);
959
960        let mut new_seg = if !segment.segments().is_empty() && old_position != Some(&new_position) {
961            let child_segments = position_segments(segment.segments(), &new_position);
962            segment.change_segments(child_segments)
963        } else {
964            segment.deep_clone()
965        };
966
967        new_seg.get_mut().set_position_marker(new_position.into());
968        segment_buffer.push(new_seg);
969    }
970
971    segment_buffer
972}
973
974#[derive(Debug, Clone)]
975pub struct NodeOrToken {
976    id: u32,
977    syntax_kind: SyntaxKind,
978    class_types: SyntaxSet,
979    position_marker: Option<PositionMarker>,
980    kind: NodeOrTokenKind,
981    code_idx: OnceCell<Rc<Vec<usize>>>,
982    hash: OnceCell<u64>,
983}
984
985#[derive(Debug, Clone)]
986pub enum NodeOrTokenKind {
987    Node(NodeData),
988    Token(TokenData),
989}
990
991impl NodeOrToken {
992    pub fn set_position_marker(&mut self, position_marker: Option<PositionMarker>) {
993        self.position_marker = position_marker;
994    }
995
996    pub fn set_id(&mut self, id: u32) {
997        self.id = id;
998    }
999}
1000
1001#[derive(Debug, Clone)]
1002pub struct NodeData {
1003    dialect: DialectKind,
1004    segments: Vec<ErasedSegment>,
1005    raw: OnceCell<SmolStr>,
1006    source_fixes: Vec<SourceFix>,
1007    descendant_type_set: OnceCell<SyntaxSet>,
1008    raw_segments_with_ancestors: OnceCell<Vec<(ErasedSegment, Vec<PathStep>)>>,
1009}
1010
1011#[derive(Debug, Clone, PartialEq)]
1012pub struct TokenData {
1013    raw: SmolStr,
1014}
1015
1016#[track_caller]
1017pub fn pos_marker(segments: &[ErasedSegment]) -> PositionMarker {
1018    let markers = segments.iter().filter_map(|seg| seg.get_position_marker());
1019
1020    PositionMarker::from_child_markers(markers)
1021}
1022
1023#[derive(Debug, Clone)]
1024pub struct PathStep {
1025    pub segment: ErasedSegment,
1026    pub idx: usize,
1027    pub len: usize,
1028    pub code_idxs: Rc<Vec<usize>>,
1029}
1030
1031fn class_types(syntax_kind: SyntaxKind) -> SyntaxSet {
1032    match syntax_kind {
1033        SyntaxKind::ColumnReference => SyntaxSet::new(&[SyntaxKind::ObjectReference, syntax_kind]),
1034        SyntaxKind::WildcardIdentifier => {
1035            SyntaxSet::new(&[SyntaxKind::WildcardIdentifier, SyntaxKind::ObjectReference])
1036        }
1037        SyntaxKind::TableReference => SyntaxSet::new(&[SyntaxKind::ObjectReference, syntax_kind]),
1038        _ => SyntaxSet::single(syntax_kind),
1039    }
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lint_fix::LintFix;
    use crate::linter::compute_anchor_edit_info;
    use crate::parser::segments::test_functions::{raw_seg, raw_segments};

    /// Test comparison of raw segments.
    #[test]
    fn test_parser_base_segments_raw_compare() {
        let template: TemplatedFile = "foobar".into();

        // Builds the same positioned `foobar` word token on every call.
        let build_word = || {
            SegmentBuilder::token(0, "foobar", SyntaxKind::Word)
                .with_position(PositionMarker::new(0..6, 0..6, template.clone(), None, None))
                .finish()
        };

        let first = build_word();
        let second = build_word();

        assert_eq!(first, second)
    }

    /// Test raw segments behave as expected.
    // TODO Implement
    #[test]
    fn test_parser_base_segments_raw() {
        assert_eq!(raw_seg().raw(), "foobar");
    }

    /// Test BaseSegment.compute_anchor_edit_info().
    #[test]
    fn test_parser_base_segments_compute_anchor_edit_info() {
        let raw_segs = raw_segments();
        let tables = Tables::default();

        let target = &raw_segs[0];
        let target_id = target.id();

        // Builds a fix replacing the target segment with an edited copy whose
        // raw text is `raw`; every call draws a fresh id from `tables`.
        let replace_with = |raw: &str| {
            LintFix::replace(
                target.clone(),
                vec![target.edit(tables.next_id(), Some(raw.to_string()), None)],
                None,
            )
        };

        // Construct a fix buffer, intentionally with:
        // - one duplicate.
        // - two different incompatible fixes on the same segment.
        let fixes = vec![replace_with("a"), replace_with("a"), replace_with("b")];

        let mut anchor_edit_info = Default::default();
        compute_anchor_edit_info(&mut anchor_edit_info, fixes);

        // Check the target segment is the only key we have.
        assert_eq!(
            anchor_edit_info.keys().collect::<Vec<_>>(),
            vec![&target_id]
        );

        let anchor_info = anchor_edit_info.get(&target_id).unwrap();

        // Check that the duplicate has been deduplicated i.e. this isn't 3.
        assert_eq!(anchor_info.replace, 2);

        // Check the fixes themselves.
        //   Note: There's no duplicated first fix.
        assert_eq!(anchor_info.fixes[0], replace_with("a"));
        assert_eq!(anchor_info.fixes[1], replace_with("b"));

        // Check the first replace
        assert_eq!(
            anchor_info.fixes[anchor_info.first_replace.unwrap()],
            replace_with("a")
        );
    }
}