typst_syntax/
node.rs

1use std::fmt::{self, Debug, Display, Formatter};
2use std::ops::{Deref, Range};
3use std::rc::Rc;
4use std::sync::Arc;
5
6use ecow::{eco_format, eco_vec, EcoString, EcoVec};
7
8use crate::ast::AstNode;
9use crate::{FileId, Span, SyntaxKind};
10
11/// A node in the untyped syntax tree.
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct SyntaxNode(Repr);
14
15/// The three internal representations.
16#[derive(Clone, Eq, PartialEq, Hash)]
17enum Repr {
18    /// A leaf node.
19    Leaf(LeafNode),
20    /// A reference-counted inner node.
21    Inner(Arc<InnerNode>),
22    /// An error node.
23    Error(Arc<ErrorNode>),
24}
25
26impl SyntaxNode {
27    /// Create a new leaf node.
28    pub fn leaf(kind: SyntaxKind, text: impl Into<EcoString>) -> Self {
29        Self(Repr::Leaf(LeafNode::new(kind, text)))
30    }
31
32    /// Create a new inner node with children.
33    pub fn inner(kind: SyntaxKind, children: Vec<SyntaxNode>) -> Self {
34        Self(Repr::Inner(Arc::new(InnerNode::new(kind, children))))
35    }
36
37    /// Create a new error node.
38    pub fn error(error: SyntaxError, text: impl Into<EcoString>) -> Self {
39        Self(Repr::Error(Arc::new(ErrorNode::new(error, text))))
40    }
41
42    /// Create a dummy node of the given kind.
43    ///
44    /// Panics if `kind` is `SyntaxKind::Error`.
45    #[track_caller]
46    pub const fn placeholder(kind: SyntaxKind) -> Self {
47        if matches!(kind, SyntaxKind::Error) {
48            panic!("cannot create error placeholder");
49        }
50        Self(Repr::Leaf(LeafNode {
51            kind,
52            text: EcoString::new(),
53            span: Span::detached(),
54        }))
55    }
56
57    /// The type of the node.
58    pub fn kind(&self) -> SyntaxKind {
59        match &self.0 {
60            Repr::Leaf(leaf) => leaf.kind,
61            Repr::Inner(inner) => inner.kind,
62            Repr::Error(_) => SyntaxKind::Error,
63        }
64    }
65
66    /// Return `true` if the length is 0.
67    pub fn is_empty(&self) -> bool {
68        self.len() == 0
69    }
70
71    /// The byte length of the node in the source text.
72    pub fn len(&self) -> usize {
73        match &self.0 {
74            Repr::Leaf(leaf) => leaf.len(),
75            Repr::Inner(inner) => inner.len,
76            Repr::Error(node) => node.len(),
77        }
78    }
79
80    /// The span of the node.
81    pub fn span(&self) -> Span {
82        match &self.0 {
83            Repr::Leaf(leaf) => leaf.span,
84            Repr::Inner(inner) => inner.span,
85            Repr::Error(node) => node.error.span,
86        }
87    }
88
89    /// The text of the node if it is a leaf or error node.
90    ///
91    /// Returns the empty string if this is an inner node.
92    pub fn text(&self) -> &EcoString {
93        static EMPTY: EcoString = EcoString::new();
94        match &self.0 {
95            Repr::Leaf(leaf) => &leaf.text,
96            Repr::Inner(_) => &EMPTY,
97            Repr::Error(node) => &node.text,
98        }
99    }
100
101    /// Extract the text from the node.
102    ///
103    /// Builds the string if this is an inner node.
104    pub fn into_text(self) -> EcoString {
105        match self.0 {
106            Repr::Leaf(leaf) => leaf.text,
107            Repr::Inner(inner) => {
108                inner.children.iter().cloned().map(Self::into_text).collect()
109            }
110            Repr::Error(node) => node.text.clone(),
111        }
112    }
113
114    /// The node's children.
115    pub fn children(&self) -> std::slice::Iter<'_, SyntaxNode> {
116        match &self.0 {
117            Repr::Leaf(_) | Repr::Error(_) => [].iter(),
118            Repr::Inner(inner) => inner.children.iter(),
119        }
120    }
121
122    /// Whether the node can be cast to the given AST node.
123    pub fn is<'a, T: AstNode<'a>>(&'a self) -> bool {
124        self.cast::<T>().is_some()
125    }
126
127    /// Try to convert the node to a typed AST node.
128    pub fn cast<'a, T: AstNode<'a>>(&'a self) -> Option<T> {
129        T::from_untyped(self)
130    }
131
132    /// Cast the first child that can cast to the AST type `T`.
133    pub fn cast_first_match<'a, T: AstNode<'a>>(&'a self) -> Option<T> {
134        self.children().find_map(Self::cast)
135    }
136
137    /// Cast the last child that can cast to the AST type `T`.
138    pub fn cast_last_match<'a, T: AstNode<'a>>(&'a self) -> Option<T> {
139        self.children().rev().find_map(Self::cast)
140    }
141
142    /// Whether the node or its children contain an error.
143    pub fn erroneous(&self) -> bool {
144        match &self.0 {
145            Repr::Leaf(_) => false,
146            Repr::Inner(inner) => inner.erroneous,
147            Repr::Error(_) => true,
148        }
149    }
150
151    /// The error messages for this node and its descendants.
152    pub fn errors(&self) -> Vec<SyntaxError> {
153        if !self.erroneous() {
154            return vec![];
155        }
156
157        if let Repr::Error(node) = &self.0 {
158            vec![node.error.clone()]
159        } else {
160            self.children()
161                .filter(|node| node.erroneous())
162                .flat_map(|node| node.errors())
163                .collect()
164        }
165    }
166
167    /// Add a user-presentable hint if this is an error node.
168    pub fn hint(&mut self, hint: impl Into<EcoString>) {
169        if let Repr::Error(node) = &mut self.0 {
170            Arc::make_mut(node).hint(hint);
171        }
172    }
173
174    /// Set a synthetic span for the node and all its descendants.
175    pub fn synthesize(&mut self, span: Span) {
176        match &mut self.0 {
177            Repr::Leaf(leaf) => leaf.span = span,
178            Repr::Inner(inner) => Arc::make_mut(inner).synthesize(span),
179            Repr::Error(node) => Arc::make_mut(node).error.span = span,
180        }
181    }
182
183    /// Whether the two syntax nodes are the same apart from spans.
184    pub fn spanless_eq(&self, other: &Self) -> bool {
185        match (&self.0, &other.0) {
186            (Repr::Leaf(a), Repr::Leaf(b)) => a.spanless_eq(b),
187            (Repr::Inner(a), Repr::Inner(b)) => a.spanless_eq(b),
188            (Repr::Error(a), Repr::Error(b)) => a.spanless_eq(b),
189            _ => false,
190        }
191    }
192}
193
194impl SyntaxNode {
195    /// Convert the child to another kind.
196    ///
197    /// Don't use this for converting to an error!
198    #[track_caller]
199    pub(super) fn convert_to_kind(&mut self, kind: SyntaxKind) {
200        debug_assert!(!kind.is_error());
201        match &mut self.0 {
202            Repr::Leaf(leaf) => leaf.kind = kind,
203            Repr::Inner(inner) => Arc::make_mut(inner).kind = kind,
204            Repr::Error(_) => panic!("cannot convert error"),
205        }
206    }
207
208    /// Convert the child to an error, if it isn't already one.
209    pub(super) fn convert_to_error(&mut self, message: impl Into<EcoString>) {
210        if !self.kind().is_error() {
211            let text = std::mem::take(self).into_text();
212            *self = SyntaxNode::error(SyntaxError::new(message), text);
213        }
214    }
215
216    /// Convert the child to an error stating that the given thing was
217    /// expected, but the current kind was found.
218    pub(super) fn expected(&mut self, expected: &str) {
219        let kind = self.kind();
220        self.convert_to_error(eco_format!("expected {expected}, found {}", kind.name()));
221        if kind.is_keyword() && matches!(expected, "identifier" | "pattern") {
222            self.hint(eco_format!(
223                "keyword `{text}` is not allowed as an identifier; try `{text}_` instead",
224                text = self.text(),
225            ));
226        }
227    }
228
229    /// Convert the child to an error stating it was unexpected.
230    pub(super) fn unexpected(&mut self) {
231        self.convert_to_error(eco_format!("unexpected {}", self.kind().name()));
232    }
233
234    /// Assign spans to each node.
235    pub(super) fn numberize(
236        &mut self,
237        id: FileId,
238        within: Range<u64>,
239    ) -> NumberingResult {
240        if within.start >= within.end {
241            return Err(Unnumberable);
242        }
243
244        let mid = Span::from_number(id, (within.start + within.end) / 2).unwrap();
245        match &mut self.0 {
246            Repr::Leaf(leaf) => leaf.span = mid,
247            Repr::Inner(inner) => Arc::make_mut(inner).numberize(id, None, within)?,
248            Repr::Error(node) => Arc::make_mut(node).error.span = mid,
249        }
250
251        Ok(())
252    }
253
254    /// Whether this is a leaf node.
255    pub(super) fn is_leaf(&self) -> bool {
256        matches!(self.0, Repr::Leaf(_))
257    }
258
259    /// The number of descendants, including the node itself.
260    pub(super) fn descendants(&self) -> usize {
261        match &self.0 {
262            Repr::Leaf(_) | Repr::Error(_) => 1,
263            Repr::Inner(inner) => inner.descendants,
264        }
265    }
266
267    /// The node's children, mutably.
268    pub(super) fn children_mut(&mut self) -> &mut [SyntaxNode] {
269        match &mut self.0 {
270            Repr::Leaf(_) | Repr::Error(_) => &mut [],
271            Repr::Inner(inner) => &mut Arc::make_mut(inner).children,
272        }
273    }
274
275    /// Replaces a range of children with a replacement.
276    ///
277    /// May have mutated the children if it returns `Err(_)`.
278    pub(super) fn replace_children(
279        &mut self,
280        range: Range<usize>,
281        replacement: Vec<SyntaxNode>,
282    ) -> NumberingResult {
283        if let Repr::Inner(inner) = &mut self.0 {
284            Arc::make_mut(inner).replace_children(range, replacement)?;
285        }
286        Ok(())
287    }
288
289    /// Update this node after changes were made to one of its children.
290    pub(super) fn update_parent(
291        &mut self,
292        prev_len: usize,
293        new_len: usize,
294        prev_descendants: usize,
295        new_descendants: usize,
296    ) {
297        if let Repr::Inner(inner) = &mut self.0 {
298            Arc::make_mut(inner).update_parent(
299                prev_len,
300                new_len,
301                prev_descendants,
302                new_descendants,
303            );
304        }
305    }
306
307    /// The upper bound of assigned numbers in this subtree.
308    pub(super) fn upper(&self) -> u64 {
309        match &self.0 {
310            Repr::Leaf(leaf) => leaf.span.number() + 1,
311            Repr::Inner(inner) => inner.upper,
312            Repr::Error(node) => node.error.span.number() + 1,
313        }
314    }
315}
316
317impl Debug for SyntaxNode {
318    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
319        match &self.0 {
320            Repr::Leaf(leaf) => leaf.fmt(f),
321            Repr::Inner(inner) => inner.fmt(f),
322            Repr::Error(node) => node.fmt(f),
323        }
324    }
325}
326
327impl Default for SyntaxNode {
328    fn default() -> Self {
329        Self::leaf(SyntaxKind::End, EcoString::new())
330    }
331}
332
333/// A leaf node in the untyped syntax tree.
334#[derive(Clone, Eq, PartialEq, Hash)]
335struct LeafNode {
336    /// What kind of node this is (each kind would have its own struct in a
337    /// strongly typed AST).
338    kind: SyntaxKind,
339    /// The source text of the node.
340    text: EcoString,
341    /// The node's span.
342    span: Span,
343}
344
345impl LeafNode {
346    /// Create a new leaf node.
347    #[track_caller]
348    fn new(kind: SyntaxKind, text: impl Into<EcoString>) -> Self {
349        debug_assert!(!kind.is_error());
350        Self { kind, text: text.into(), span: Span::detached() }
351    }
352
353    /// The byte length of the node in the source text.
354    fn len(&self) -> usize {
355        self.text.len()
356    }
357
358    /// Whether the two leaf nodes are the same apart from spans.
359    fn spanless_eq(&self, other: &Self) -> bool {
360        self.kind == other.kind && self.text == other.text
361    }
362}
363
364impl Debug for LeafNode {
365    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
366        write!(f, "{:?}: {:?}", self.kind, self.text)
367    }
368}
369
370/// An inner node in the untyped syntax tree.
371#[derive(Clone, Eq, PartialEq, Hash)]
372struct InnerNode {
373    /// What kind of node this is (each kind would have its own struct in a
374    /// strongly typed AST).
375    kind: SyntaxKind,
376    /// The byte length of the node in the source.
377    len: usize,
378    /// The node's span.
379    span: Span,
380    /// The number of nodes in the whole subtree, including this node.
381    descendants: usize,
382    /// Whether this node or any of its children are erroneous.
383    erroneous: bool,
384    /// The upper bound of this node's numbering range.
385    upper: u64,
386    /// This node's children, losslessly make up this node.
387    children: Vec<SyntaxNode>,
388}
389
390impl InnerNode {
391    /// Create a new inner node with the given kind and children.
392    #[track_caller]
393    fn new(kind: SyntaxKind, children: Vec<SyntaxNode>) -> Self {
394        debug_assert!(!kind.is_error());
395
396        let mut len = 0;
397        let mut descendants = 1;
398        let mut erroneous = false;
399
400        for child in &children {
401            len += child.len();
402            descendants += child.descendants();
403            erroneous |= child.erroneous();
404        }
405
406        Self {
407            kind,
408            len,
409            span: Span::detached(),
410            descendants,
411            erroneous,
412            upper: 0,
413            children,
414        }
415    }
416
417    /// Set a synthetic span for the node and all its descendants.
418    fn synthesize(&mut self, span: Span) {
419        self.span = span;
420        self.upper = span.number();
421        for child in &mut self.children {
422            child.synthesize(span);
423        }
424    }
425
426    /// Assign span numbers `within` an interval to this node's subtree or just
427    /// a `range` of its children.
428    fn numberize(
429        &mut self,
430        id: FileId,
431        range: Option<Range<usize>>,
432        within: Range<u64>,
433    ) -> NumberingResult {
434        // Determine how many nodes we will number.
435        let descendants = match &range {
436            Some(range) if range.is_empty() => return Ok(()),
437            Some(range) => self.children[range.clone()]
438                .iter()
439                .map(SyntaxNode::descendants)
440                .sum::<usize>(),
441            None => self.descendants,
442        };
443
444        // Determine the distance between two neighbouring assigned numbers. If
445        // possible, we try to fit all numbers into the left half of `within`
446        // so that there is space for future insertions.
447        let space = within.end - within.start;
448        let mut stride = space / (2 * descendants as u64);
449        if stride == 0 {
450            stride = space / self.descendants as u64;
451            if stride == 0 {
452                return Err(Unnumberable);
453            }
454        }
455
456        // Number the node itself.
457        let mut start = within.start;
458        if range.is_none() {
459            let end = start + stride;
460            self.span = Span::from_number(id, (start + end) / 2).unwrap();
461            self.upper = within.end;
462            start = end;
463        }
464
465        // Number the children.
466        let len = self.children.len();
467        for child in &mut self.children[range.unwrap_or(0..len)] {
468            let end = start + child.descendants() as u64 * stride;
469            child.numberize(id, start..end)?;
470            start = end;
471        }
472
473        Ok(())
474    }
475
476    /// Whether the two inner nodes are the same apart from spans.
477    fn spanless_eq(&self, other: &Self) -> bool {
478        self.kind == other.kind
479            && self.len == other.len
480            && self.descendants == other.descendants
481            && self.erroneous == other.erroneous
482            && self.children.len() == other.children.len()
483            && self
484                .children
485                .iter()
486                .zip(&other.children)
487                .all(|(a, b)| a.spanless_eq(b))
488    }
489
490    /// Replaces a range of children with a replacement.
491    ///
492    /// May have mutated the children if it returns `Err(_)`.
493    fn replace_children(
494        &mut self,
495        mut range: Range<usize>,
496        replacement: Vec<SyntaxNode>,
497    ) -> NumberingResult {
498        let Some(id) = self.span.id() else { return Err(Unnumberable) };
499        let mut replacement_range = 0..replacement.len();
500
501        // Trim off common prefix.
502        while range.start < range.end
503            && replacement_range.start < replacement_range.end
504            && self.children[range.start]
505                .spanless_eq(&replacement[replacement_range.start])
506        {
507            range.start += 1;
508            replacement_range.start += 1;
509        }
510
511        // Trim off common suffix.
512        while range.start < range.end
513            && replacement_range.start < replacement_range.end
514            && self.children[range.end - 1]
515                .spanless_eq(&replacement[replacement_range.end - 1])
516        {
517            range.end -= 1;
518            replacement_range.end -= 1;
519        }
520
521        let mut replacement_vec = replacement;
522        let replacement = &replacement_vec[replacement_range.clone()];
523        let superseded = &self.children[range.clone()];
524
525        // Compute the new byte length.
526        self.len = self.len + replacement.iter().map(SyntaxNode::len).sum::<usize>()
527            - superseded.iter().map(SyntaxNode::len).sum::<usize>();
528
529        // Compute the new number of descendants.
530        self.descendants = self.descendants
531            + replacement.iter().map(SyntaxNode::descendants).sum::<usize>()
532            - superseded.iter().map(SyntaxNode::descendants).sum::<usize>();
533
534        // Determine whether we're still erroneous after the replacement. That's
535        // the case if
536        // - any of the new nodes is erroneous,
537        // - or if we were erroneous before due to a non-superseded node.
538        self.erroneous = replacement.iter().any(SyntaxNode::erroneous)
539            || (self.erroneous
540                && (self.children[..range.start].iter().any(SyntaxNode::erroneous))
541                || self.children[range.end..].iter().any(SyntaxNode::erroneous));
542
543        // Perform the replacement.
544        self.children
545            .splice(range.clone(), replacement_vec.drain(replacement_range.clone()));
546        range.end = range.start + replacement_range.len();
547
548        // Renumber the new children. Retries until it works, taking
549        // exponentially more children into account.
550        let mut left = 0;
551        let mut right = 0;
552        let max_left = range.start;
553        let max_right = self.children.len() - range.end;
554        loop {
555            let renumber = range.start - left..range.end + right;
556
557            // The minimum assignable number is either
558            // - the upper bound of the node right before the to-be-renumbered
559            //   children,
560            // - or this inner node's span number plus one if renumbering starts
561            //   at the first child.
562            let start_number = renumber
563                .start
564                .checked_sub(1)
565                .and_then(|i| self.children.get(i))
566                .map_or(self.span.number() + 1, |child| child.upper());
567
568            // The upper bound for renumbering is either
569            // - the span number of the first child after the to-be-renumbered
570            //   children,
571            // - or this node's upper bound if renumbering ends behind the last
572            //   child.
573            let end_number = self
574                .children
575                .get(renumber.end)
576                .map_or(self.upper, |next| next.span().number());
577
578            // Try to renumber.
579            let within = start_number..end_number;
580            if self.numberize(id, Some(renumber), within).is_ok() {
581                return Ok(());
582            }
583
584            // If it didn't even work with all children, we give up.
585            if left == max_left && right == max_right {
586                return Err(Unnumberable);
587            }
588
589            // Exponential expansion to both sides.
590            left = (left + 1).next_power_of_two().min(max_left);
591            right = (right + 1).next_power_of_two().min(max_right);
592        }
593    }
594
595    /// Update this node after changes were made to one of its children.
596    fn update_parent(
597        &mut self,
598        prev_len: usize,
599        new_len: usize,
600        prev_descendants: usize,
601        new_descendants: usize,
602    ) {
603        self.len = self.len + new_len - prev_len;
604        self.descendants = self.descendants + new_descendants - prev_descendants;
605        self.erroneous = self.children.iter().any(SyntaxNode::erroneous);
606    }
607}
608
609impl Debug for InnerNode {
610    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
611        write!(f, "{:?}: {}", self.kind, self.len)?;
612        if !self.children.is_empty() {
613            f.write_str(" ")?;
614            f.debug_list().entries(&self.children).finish()?;
615        }
616        Ok(())
617    }
618}
619
620/// An error node in the untyped syntax tree.
621#[derive(Clone, Eq, PartialEq, Hash)]
622struct ErrorNode {
623    /// The source text of the node.
624    text: EcoString,
625    /// The syntax error.
626    error: SyntaxError,
627}
628
629impl ErrorNode {
630    /// Create new error node.
631    fn new(error: SyntaxError, text: impl Into<EcoString>) -> Self {
632        Self { text: text.into(), error }
633    }
634
635    /// The byte length of the node in the source text.
636    fn len(&self) -> usize {
637        self.text.len()
638    }
639
640    /// Add a user-presentable hint to this error node.
641    fn hint(&mut self, hint: impl Into<EcoString>) {
642        self.error.hints.push(hint.into());
643    }
644
645    /// Whether the two leaf nodes are the same apart from spans.
646    fn spanless_eq(&self, other: &Self) -> bool {
647        self.text == other.text && self.error.spanless_eq(&other.error)
648    }
649}
650
651impl Debug for ErrorNode {
652    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
653        write!(f, "Error: {:?} ({})", self.text, self.error.message)
654    }
655}
656
657/// A syntactical error.
658#[derive(Debug, Clone, Eq, PartialEq, Hash)]
659pub struct SyntaxError {
660    /// The node's span.
661    pub span: Span,
662    /// The error message.
663    pub message: EcoString,
664    /// Additional hints to the user, indicating how this error could be avoided
665    /// or worked around.
666    pub hints: EcoVec<EcoString>,
667}
668
669impl SyntaxError {
670    /// Create a new detached syntax error.
671    pub fn new(message: impl Into<EcoString>) -> Self {
672        Self {
673            span: Span::detached(),
674            message: message.into(),
675            hints: eco_vec![],
676        }
677    }
678
679    /// Whether the two errors are the same apart from spans.
680    fn spanless_eq(&self, other: &Self) -> bool {
681        self.message == other.message && self.hints == other.hints
682    }
683}
684
685/// A syntax node in a context.
686///
687/// Knows its exact offset in the file and provides access to its
688/// children, parent and siblings.
689///
690/// **Note that all sibling and leaf accessors skip over trivia!**
691#[derive(Clone)]
692pub struct LinkedNode<'a> {
693    node: &'a SyntaxNode,
694    parent: Option<Rc<Self>>,
695    index: usize,
696    offset: usize,
697}
698
699impl<'a> LinkedNode<'a> {
700    /// Start a new traversal at a root node.
701    pub fn new(root: &'a SyntaxNode) -> Self {
702        Self { node: root, parent: None, index: 0, offset: 0 }
703    }
704
705    /// Get the contained syntax node.
706    pub fn get(&self) -> &'a SyntaxNode {
707        self.node
708    }
709
710    /// The index of this node in its parent's children list.
711    pub fn index(&self) -> usize {
712        self.index
713    }
714
715    /// The absolute byte offset of this node in the source file.
716    pub fn offset(&self) -> usize {
717        self.offset
718    }
719
720    /// The byte range of this node in the source file.
721    pub fn range(&self) -> Range<usize> {
722        self.offset..self.offset + self.node.len()
723    }
724
725    /// An iterator over this node's children.
726    pub fn children(&self) -> LinkedChildren<'a> {
727        LinkedChildren {
728            parent: Rc::new(self.clone()),
729            iter: self.node.children().enumerate(),
730            front: self.offset,
731            back: self.offset + self.len(),
732        }
733    }
734
735    /// Find a descendant with the given span.
736    pub fn find(&self, span: Span) -> Option<LinkedNode<'a>> {
737        if self.span() == span {
738            return Some(self.clone());
739        }
740
741        if let Repr::Inner(inner) = &self.0 {
742            // The parent of a subtree has a smaller span number than all of its
743            // descendants. Therefore, we can bail out early if the target span's
744            // number is smaller than our number.
745            if span.number() < inner.span.number() {
746                return None;
747            }
748
749            let mut children = self.children().peekable();
750            while let Some(child) = children.next() {
751                // Every node in this child's subtree has a smaller span number than
752                // the next sibling. Therefore we only need to recurse if the next
753                // sibling's span number is larger than the target span's number.
754                if children
755                    .peek()
756                    .map_or(true, |next| next.span().number() > span.number())
757                {
758                    if let Some(found) = child.find(span) {
759                        return Some(found);
760                    }
761                }
762            }
763        }
764
765        None
766    }
767}
768
769/// Access to parents and siblings.
770impl LinkedNode<'_> {
771    /// Get this node's parent.
772    pub fn parent(&self) -> Option<&Self> {
773        self.parent.as_deref()
774    }
775
776    /// Get the first previous non-trivia sibling node.
777    pub fn prev_sibling(&self) -> Option<Self> {
778        let parent = self.parent()?;
779        let index = self.index.checked_sub(1)?;
780        let node = parent.node.children().nth(index)?;
781        let offset = self.offset - node.len();
782        let prev = Self { node, parent: self.parent.clone(), index, offset };
783        if prev.kind().is_trivia() {
784            prev.prev_sibling()
785        } else {
786            Some(prev)
787        }
788    }
789
790    /// Get the next non-trivia sibling node.
791    pub fn next_sibling(&self) -> Option<Self> {
792        let parent = self.parent()?;
793        let index = self.index.checked_add(1)?;
794        let node = parent.node.children().nth(index)?;
795        let offset = self.offset + self.node.len();
796        let next = Self { node, parent: self.parent.clone(), index, offset };
797        if next.kind().is_trivia() {
798            next.next_sibling()
799        } else {
800            Some(next)
801        }
802    }
803
804    /// Get the kind of this node's parent.
805    pub fn parent_kind(&self) -> Option<SyntaxKind> {
806        Some(self.parent()?.node.kind())
807    }
808
809    /// Get the kind of this node's first previous non-trivia sibling.
810    pub fn prev_sibling_kind(&self) -> Option<SyntaxKind> {
811        Some(self.prev_sibling()?.node.kind())
812    }
813
814    /// Get the kind of this node's next non-trivia sibling.
815    pub fn next_sibling_kind(&self) -> Option<SyntaxKind> {
816        Some(self.next_sibling()?.node.kind())
817    }
818}
819
/// Indicates whether the cursor is before the related byte index, or after.
///
/// Used by [`LinkedNode::leaf_at`] to decide which leaf to pick when the
/// cursor sits exactly on the boundary between two leaves.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Side {
    /// Prefer the leaf that ends at the cursor.
    Before,
    /// Prefer the leaf that starts at the cursor.
    After,
}
826
827/// Access to leaves.
828impl LinkedNode<'_> {
829    /// Get the rightmost non-trivia leaf before this node.
830    pub fn prev_leaf(&self) -> Option<Self> {
831        let mut node = self.clone();
832        while let Some(prev) = node.prev_sibling() {
833            if let Some(leaf) = prev.rightmost_leaf() {
834                return Some(leaf);
835            }
836            node = prev;
837        }
838        self.parent()?.prev_leaf()
839    }
840
841    /// Find the leftmost contained non-trivia leaf.
842    pub fn leftmost_leaf(&self) -> Option<Self> {
843        if self.is_leaf() && !self.kind().is_trivia() && !self.kind().is_error() {
844            return Some(self.clone());
845        }
846
847        for child in self.children() {
848            if let Some(leaf) = child.leftmost_leaf() {
849                return Some(leaf);
850            }
851        }
852
853        None
854    }
855
856    /// Get the leaf immediately before the specified byte offset.
857    fn leaf_before(&self, cursor: usize) -> Option<Self> {
858        if self.node.children().len() == 0 && cursor <= self.offset + self.len() {
859            return Some(self.clone());
860        }
861
862        let mut offset = self.offset;
863        let count = self.node.children().len();
864        for (i, child) in self.children().enumerate() {
865            let len = child.len();
866            if (offset < cursor && cursor <= offset + len)
867                || (offset == cursor && i + 1 == count)
868            {
869                return child.leaf_before(cursor);
870            }
871            offset += len;
872        }
873
874        None
875    }
876
877    /// Get the leaf after the specified byte offset.
878    fn leaf_after(&self, cursor: usize) -> Option<Self> {
879        if self.node.children().len() == 0 && cursor < self.offset + self.len() {
880            return Some(self.clone());
881        }
882
883        let mut offset = self.offset;
884        for child in self.children() {
885            let len = child.len();
886            if offset <= cursor && cursor < offset + len {
887                return child.leaf_after(cursor);
888            }
889            offset += len;
890        }
891
892        None
893    }
894
895    /// Get the leaf at the specified byte offset.
896    pub fn leaf_at(&self, cursor: usize, side: Side) -> Option<Self> {
897        match side {
898            Side::Before => self.leaf_before(cursor),
899            Side::After => self.leaf_after(cursor),
900        }
901    }
902
903    /// Find the rightmost contained non-trivia leaf.
904    pub fn rightmost_leaf(&self) -> Option<Self> {
905        if self.is_leaf() && !self.kind().is_trivia() {
906            return Some(self.clone());
907        }
908
909        for child in self.children().rev() {
910            if let Some(leaf) = child.rightmost_leaf() {
911                return Some(leaf);
912            }
913        }
914
915        None
916    }
917
918    /// Get the leftmost non-trivia leaf after this node.
919    pub fn next_leaf(&self) -> Option<Self> {
920        let mut node = self.clone();
921        while let Some(next) = node.next_sibling() {
922            if let Some(leaf) = next.leftmost_leaf() {
923                return Some(leaf);
924            }
925            node = next;
926        }
927        self.parent()?.next_leaf()
928    }
929}
930
931impl Deref for LinkedNode<'_> {
932    type Target = SyntaxNode;
933
934    /// Dereference to a syntax node. Note that this shortens the lifetime, so
935    /// you may need to use [`get()`](Self::get) instead in some situations.
936    fn deref(&self) -> &Self::Target {
937        self.get()
938    }
939}
940
941impl Debug for LinkedNode<'_> {
942    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
943        self.node.fmt(f)
944    }
945}
946
947/// An iterator over the children of a linked node.
948pub struct LinkedChildren<'a> {
949    parent: Rc<LinkedNode<'a>>,
950    iter: std::iter::Enumerate<std::slice::Iter<'a, SyntaxNode>>,
951    front: usize,
952    back: usize,
953}
954
955impl<'a> Iterator for LinkedChildren<'a> {
956    type Item = LinkedNode<'a>;
957
958    fn next(&mut self) -> Option<Self::Item> {
959        self.iter.next().map(|(index, node)| {
960            let offset = self.front;
961            self.front += node.len();
962            LinkedNode {
963                node,
964                parent: Some(self.parent.clone()),
965                index,
966                offset,
967            }
968        })
969    }
970
971    fn size_hint(&self) -> (usize, Option<usize>) {
972        self.iter.size_hint()
973    }
974}
975
976impl DoubleEndedIterator for LinkedChildren<'_> {
977    fn next_back(&mut self) -> Option<Self::Item> {
978        self.iter.next_back().map(|(index, node)| {
979            self.back -= node.len();
980            LinkedNode {
981                node,
982                parent: Some(self.parent.clone()),
983                index,
984                offset: self.back,
985            }
986        })
987    }
988}
989
// `size_hint` delegates to the enumerated slice iterator, whose bounds are
// exact, so the default `len()` derived from it is correct.
impl ExactSizeIterator for LinkedChildren<'_> {}
991
/// Result of numbering a node within an interval.
///
/// `Ok(())` means numbering succeeded. `Err(Unnumberable)` means the node
/// cannot be numbered within the given interval.
pub(super) type NumberingResult = Result<(), Unnumberable>;
994
995/// Indicates that a node cannot be numbered within a given interval.
996#[derive(Debug, Copy, Clone, Eq, PartialEq)]
997pub(super) struct Unnumberable;
998
999impl Display for Unnumberable {
1000    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
1001        f.pad("cannot number within this interval")
1002    }
1003}
1004
// Lets `Unnumberable` be used as a standard error (e.g. boxed as
// `dyn std::error::Error`). Its message comes from the `Display` impl.
impl std::error::Error for Unnumberable {}
1006
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Source;

    /// Resolve the leaf at `cursor` on the given `side` within `source`,
    /// panicking if there is none.
    fn leaf(source: &Source, cursor: usize, side: Side) -> LinkedNode<'_> {
        LinkedNode::new(source.root()).leaf_at(cursor, side).unwrap()
    }

    #[test]
    fn test_linked_node() {
        let source = Source::detached("#set text(12pt, red)");

        // A cursor in the middle of "text" resolves to it from either side.
        for side in [Side::Before, Side::After] {
            let text = leaf(&source, 7, side);
            assert_eq!(text.offset(), 5);
            assert_eq!(text.text(), "text");
        }

        // Stepping back from "text" skips the space and lands on "set".
        let set = leaf(&source, 7, Side::After).prev_sibling().unwrap();
        assert_eq!(set.offset(), 1);
        assert_eq!(set.text(), "set");
    }

    #[test]
    fn test_linked_node_non_trivia_leaf() {
        let source = Source::detached("#set fun(12pt, red)");
        let fun = leaf(&source, 6, Side::Before);
        assert_eq!(fun.text(), "fun");
        assert_eq!(fun.prev_leaf().unwrap().text(), "set");

        // Position 9 of "#let x = 10" lies between the last space and "10".
        let source = Source::detached("#let x = 10");

        // Looking before the cursor lands on the space.
        let space = leaf(&source, 9, Side::Before);
        assert_eq!(space.prev_leaf().unwrap().text(), "=");
        assert_eq!(space.text(), " ");
        assert_eq!(space.next_leaf().unwrap().text(), "10");

        // Looking after it lands on the literal, which has no leaf after it.
        let ten = leaf(&source, 9, Side::After);
        assert!(ten.next_leaf().is_none());
        assert_eq!(ten.prev_leaf().unwrap().text(), "=");
        assert_eq!(ten.text(), "10");
    }
}