selene_lib/ast_util/
loop_tracker.rs1#![allow(dead_code)]
3
4use full_moon::{ast, node::Node, visitors::Visitor};
5
/// A change point in the loop-depth table: starting at `byte`, and until the
/// next change point, the code is nested `depth` loops deep.
#[derive(Clone, Copy, Debug)]
struct LoopDepth {
    /// Byte offset into the source at which `depth` takes effect.
    byte: usize,
    /// Number of loops enclosing `byte`.
    depth: u32,
}
11
12#[derive(Debug)]
13pub struct LoopTracker {
14 loop_depths: Vec<LoopDepth>,
15}
16
17impl LoopTracker {
18 pub fn new(ast: &ast::Ast) -> Self {
19 let mut visitor = LoopTrackerVisitor {
20 loop_depths: Vec::new(),
21 depth: 0,
22 };
23
24 visitor.visit_ast(ast);
25
26 assert_eq!(
27 visitor.depth, 0,
28 "Loop depth at the end of a loop tracker should be 0"
29 );
30
31 let mut loop_depths = visitor.loop_depths;
32
33 loop_depths.sort_by_cached_key(|loop_depth| loop_depth.byte);
34
35 Self { loop_depths }
36 }
37
38 pub fn depth_at_byte(&self, byte: usize) -> u32 {
39 match self
40 .loop_depths
41 .binary_search_by_key(&byte, |loop_depth| loop_depth.byte)
42 {
43 Ok(index) => self.loop_depths[index].depth,
44 Err(index) => {
45 if index == 0 {
46 0
47 } else {
48 self.loop_depths[index - 1].depth
49 }
50 }
51 }
52 }
53}
54
55struct LoopTrackerVisitor {
56 loop_depths: Vec<LoopDepth>,
57 depth: u32,
58}
59
60impl LoopTrackerVisitor {
61 fn add_loop_depth(&mut self, node: impl Node) {
62 self.depth += 1;
63
64 let Some((start, _)) = node.range() else {
65 return;
66 };
67
68 self.loop_depths.push(LoopDepth {
69 byte: start.bytes(),
70 depth: self.depth,
71 });
72 }
73
74 fn remove_loop_depth(&mut self, node: impl Node) {
75 self.depth -= 1;
76
77 let Some((_, end)) = node.range() else {
78 return;
79 };
80
81 self.loop_depths.push(LoopDepth {
82 byte: end.bytes(),
83 depth: self.depth,
84 });
85 }
86}
87
88impl Visitor for LoopTrackerVisitor {
89 fn visit_generic_for(&mut self, node: &ast::GenericFor) {
90 self.add_loop_depth(node.block());
91 }
92
93 fn visit_generic_for_end(&mut self, node: &ast::GenericFor) {
94 self.remove_loop_depth(node);
95 }
96
97 fn visit_numeric_for(&mut self, node: &ast::NumericFor) {
98 self.add_loop_depth(node.block());
99 }
100
101 fn visit_numeric_for_end(&mut self, node: &ast::NumericFor) {
102 self.remove_loop_depth(node);
103 }
104
105 fn visit_while(&mut self, node: &ast::While) {
106 self.add_loop_depth(node.block());
107 }
108
109 fn visit_while_end(&mut self, node: &ast::While) {
110 self.remove_loop_depth(node);
111 }
112
113 fn visit_repeat(&mut self, node: &ast::Repeat) {
114 self.add_loop_depth(node.block());
115 }
116
117 fn visit_repeat_end(&mut self, node: &ast::Repeat) {
118 self.remove_loop_depth(node);
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    use regex::Regex;

    /// Finds every `expect(N)` marker in `code`. Returns the marker's byte
    /// offset together with the loop depth `N` it asserts.
    fn expected_depths(code: &str) -> Vec<(usize, u32)> {
        let regex = Regex::new(r#"expect\((\d+)\)"#).unwrap();

        regex
            .captures_iter(code)
            .map(|captures| {
                (
                    captures.get(0).unwrap().start(),
                    captures[1].parse().unwrap(),
                )
            })
            .collect()
    }

    /// Parses `code` and checks that the tracker reports, at each `expect(N)`
    /// marker, the depth `N` written in that marker.
    fn test_depths(code: &str) {
        let expected_depths = expected_depths(code);
        assert!(
            !expected_depths.is_empty(),
            "test code must contain at least one expect(N) marker"
        );

        let ast = full_moon::parse(code).unwrap();

        let loop_tracker = LoopTracker::new(&ast);

        for (byte, expected_depth) in expected_depths {
            let actual_depth = loop_tracker.depth_at_byte(byte);

            // Name the failing marker so a broken expectation is easy to find.
            assert_eq!(
                actual_depth,
                expected_depth,
                "wrong loop depth at byte {byte}: {:?}",
                code[byte..].lines().next().unwrap_or(""),
            );
        }
    }

    #[test]
    fn loop_tracker() {
        test_depths(
            r#"
            expect(0)

            while true do
                expect(1)

                for i, v in pairs({}) do
                    expect(2)

                    repeat
                        expect(3)
                    until true
                end

                expect(1)
            end

            expect(0)
            "#,
        );
    }

    #[test]
    fn sibling_loops_and_conditions() {
        test_depths(
            r#"
            expect(0)

            for i = 1, 10 do
                expect(1)
            end

            expect(0)

            while true do end

            expect(0)

            repeat
                expect(1)
            until expect(1)

            expect(0)
            "#,
        );
    }
}