rustpython_ruff_python_ast/
find_node.rs1use crate::AnyNodeRef;
2use crate::visitor::source_order::{SourceOrderVisitor, TraversalSignal, walk_node};
3use ruff_text_size::{Ranged, TextRange};
4use std::fmt;
5use std::fmt::Formatter;
6
7pub fn covering_node(root: AnyNodeRef, range: TextRange) -> CoveringNode {
15 struct Visitor<'a> {
16 range: TextRange,
17 found: bool,
18 ancestors: Vec<AnyNodeRef<'a>>,
19 }
20
21 impl<'a> SourceOrderVisitor<'a> for Visitor<'a> {
22 fn enter_node(&mut self, node: AnyNodeRef<'a>) -> TraversalSignal {
23 if !self.found && node.range().contains_range(self.range) {
26 self.ancestors.push(node);
27 TraversalSignal::Traverse
28 } else {
29 TraversalSignal::Skip
30 }
31 }
32
33 fn leave_node(&mut self, node: AnyNodeRef<'a>) {
34 if !self.found && self.ancestors.last() == Some(&node) {
35 self.found = true;
36 }
37 }
38 }
39
40 assert!(
41 root.range().contains_range(range),
42 "Range is not contained within root"
43 );
44
45 let mut visitor = Visitor {
46 range,
47 found: false,
48 ancestors: Vec::new(),
49 };
50
51 walk_node(&mut visitor, root);
52 CoveringNode::from_ancestors(visitor.ancestors)
53}
54
55pub struct CoveringNode<'a> {
57 nodes: Vec<AnyNodeRef<'a>>,
62}
63
64impl<'a> CoveringNode<'a> {
65 pub fn from_ancestors(ancestors: Vec<AnyNodeRef<'a>>) -> Self {
68 Self { nodes: ancestors }
69 }
70
71 pub fn node(&self) -> AnyNodeRef<'a> {
73 *self
74 .nodes
75 .last()
76 .expect("`CoveringNode::nodes` should always be non-empty")
77 }
78
79 pub fn parent(&self) -> Option<AnyNodeRef<'a>> {
81 let penultimate = self.nodes.len().checked_sub(2)?;
82 self.nodes.get(penultimate).copied()
83 }
84
85 pub fn find_first(mut self, f: impl Fn(AnyNodeRef<'a>) -> bool) -> Result<Self, Self> {
91 let Some(index) = self.find_first_index(f) else {
92 return Err(self);
93 };
94 self.nodes.truncate(index + 1);
95 Ok(self)
96 }
97
98 pub fn find_last(mut self, f: impl Fn(AnyNodeRef<'a>) -> bool) -> Result<Self, Self> {
106 let Some(mut index) = self.find_first_index(&f) else {
107 return Err(self);
108 };
109 while index > 0 && f(self.nodes[index - 1]) {
110 index -= 1;
111 }
112 self.nodes.truncate(index + 1);
113 Ok(self)
114 }
115
116 pub fn ancestors(&self) -> impl DoubleEndedIterator<Item = AnyNodeRef<'a>> + '_ {
119 self.nodes.iter().copied().rev()
120 }
121
122 fn find_first_index(&self, f: impl Fn(AnyNodeRef<'a>) -> bool) -> Option<usize> {
128 self.nodes.iter().rposition(|node| f(*node))
129 }
130}
131
132impl fmt::Debug for CoveringNode<'_> {
133 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
134 f.debug_tuple("CoveringNode").field(&self.node()).finish()
135 }
136}