selene_lib/ast_util/loop_tracker.rs

1// Remove this once loop_tracker is used again
2#![allow(dead_code)]
3
4use full_moon::{ast, node::Node, visitors::Visitor};
5
/// A point in the source where the loop nesting depth changes.
#[derive(Clone, Copy, Debug)]
struct LoopDepth {
    /// Byte offset at which `depth` takes effect.
    byte: usize,
    /// Number of enclosing loops from `byte` onward, until the next change point.
    depth: u32,
}
11
12#[derive(Debug)]
13pub struct LoopTracker {
14    loop_depths: Vec<LoopDepth>,
15}
16
17impl LoopTracker {
18    pub fn new(ast: &ast::Ast) -> Self {
19        let mut visitor = LoopTrackerVisitor {
20            loop_depths: Vec::new(),
21            depth: 0,
22        };
23
24        visitor.visit_ast(ast);
25
26        assert_eq!(
27            visitor.depth, 0,
28            "Loop depth at the end of a loop tracker should be 0"
29        );
30
31        let mut loop_depths = visitor.loop_depths;
32
33        loop_depths.sort_by_cached_key(|loop_depth| loop_depth.byte);
34
35        Self { loop_depths }
36    }
37
38    pub fn depth_at_byte(&self, byte: usize) -> u32 {
39        match self
40            .loop_depths
41            .binary_search_by_key(&byte, |loop_depth| loop_depth.byte)
42        {
43            Ok(index) => self.loop_depths[index].depth,
44            Err(index) => {
45                if index == 0 {
46                    0
47                } else {
48                    self.loop_depths[index - 1].depth
49                }
50            }
51        }
52    }
53}
54
55struct LoopTrackerVisitor {
56    loop_depths: Vec<LoopDepth>,
57    depth: u32,
58}
59
60impl LoopTrackerVisitor {
61    fn add_loop_depth(&mut self, node: impl Node) {
62        self.depth += 1;
63
64        let Some((start, _)) = node.range() else {
65            return;
66        };
67
68        self.loop_depths.push(LoopDepth {
69            byte: start.bytes(),
70            depth: self.depth,
71        });
72    }
73
74    fn remove_loop_depth(&mut self, node: impl Node) {
75        self.depth -= 1;
76
77        let Some((_, end)) = node.range() else {
78            return;
79        };
80
81        self.loop_depths.push(LoopDepth {
82            byte: end.bytes(),
83            depth: self.depth,
84        });
85    }
86}
87
88impl Visitor for LoopTrackerVisitor {
89    fn visit_generic_for(&mut self, node: &ast::GenericFor) {
90        self.add_loop_depth(node.block());
91    }
92
93    fn visit_generic_for_end(&mut self, node: &ast::GenericFor) {
94        self.remove_loop_depth(node);
95    }
96
97    fn visit_numeric_for(&mut self, node: &ast::NumericFor) {
98        self.add_loop_depth(node.block());
99    }
100
101    fn visit_numeric_for_end(&mut self, node: &ast::NumericFor) {
102        self.remove_loop_depth(node);
103    }
104
105    fn visit_while(&mut self, node: &ast::While) {
106        self.add_loop_depth(node.block());
107    }
108
109    fn visit_while_end(&mut self, node: &ast::While) {
110        self.remove_loop_depth(node);
111    }
112
113    fn visit_repeat(&mut self, node: &ast::Repeat) {
114        self.add_loop_depth(node.block());
115    }
116
117    fn visit_repeat_end(&mut self, node: &ast::Repeat) {
118        self.remove_loop_depth(node);
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    use regex::Regex;

    /// Collects every `expect(N)` marker in `code` as `(byte offset, N)`.
    ///
    /// The offset is where the `expect` call begins, which is the position
    /// queried against the tracker.
    fn expected_depths(code: &str) -> Vec<(usize, u32)> {
        let regex = Regex::new(r#"expect\((\d+)\)"#).unwrap();

        regex
            .captures_iter(code)
            .map(|captures| {
                (
                    captures.get(0).unwrap().start(),
                    captures.get(1).unwrap().as_str().parse().unwrap(),
                )
            })
            .collect()
    }

    /// Parses `code` and asserts that every `expect(N)` marker is at loop depth N.
    fn test_depths(code: &str) {
        let expected_depths = expected_depths(code);
        assert!(
            !expected_depths.is_empty(),
            "test code must contain at least one expect(N) marker"
        );

        let ast = full_moon::parse(code).unwrap();

        let loop_tracker = LoopTracker::new(&ast);

        for (byte, expected_depth) in expected_depths {
            let actual_depth = loop_tracker.depth_at_byte(byte);

            // Report the offset so a failure points at the offending marker.
            assert_eq!(
                actual_depth, expected_depth,
                "wrong loop depth for marker at byte {byte}"
            );
        }
    }

    #[test]
    fn loop_tracker() {
        test_depths(
            r#"
            expect(0)

            while true do
                expect(1)

                for i, v in pairs({}) do
                    expect(2)

                    repeat
                        expect(3)
                    until true
                end

                expect(1)
            end

            expect(0)
        "#,
        );
    }

    // Exercises the `visit_numeric_for` hooks, which `loop_tracker` does not reach.
    #[test]
    fn numeric_for() {
        test_depths(
            r#"
            expect(0)

            for i = 1, 10 do
                expect(1)
            end

            expect(0)
        "#,
        );
    }

    // Depth must fall back to 0 between consecutive top-level loops.
    #[test]
    fn sibling_loops() {
        test_depths(
            r#"
            while true do
                expect(1)
            end

            expect(0)

            repeat
                expect(1)
            until true

            expect(0)
        "#,
        );
    }
}