1use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// Directed call graph over the words defined in a [`Program`].
///
/// An edge `a -> b` exists when the body of word `a` contains a call to `b`
/// (including calls nested inside `if` branches, quotations and match arms)
/// and `b` is itself one of the program's words. Calls to any other name are
/// not recorded.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// Caller name -> set of program words it calls directly.
    edges: HashMap<String, HashSet<String>>,
    /// Names of every word defined in the program.
    words: HashSet<String>,
    /// Strongly connected components that involve recursion: those with more
    /// than one word, or a single word that calls itself.
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53 pub fn build(program: &Program) -> Self {
58 let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59 let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61 for word in &program.words {
62 let callees = extract_calls(&word.body, &words);
63 edges.insert(word.name.clone(), callees);
64 }
65
66 let mut graph = CallGraph {
67 edges,
68 words,
69 recursive_sccs: Vec::new(),
70 };
71
72 graph.recursive_sccs = graph.find_sccs();
74
75 graph
76 }
77
78 pub fn is_recursive(&self, word: &str) -> bool {
80 self.recursive_sccs.iter().any(|scc| scc.contains(word))
81 }
82
83 pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85 self.recursive_sccs
86 .iter()
87 .any(|scc| scc.contains(word1) && scc.contains(word2))
88 }
89
90 pub fn recursive_cycles(&self) -> &[HashSet<String>] {
92 &self.recursive_sccs
93 }
94
95 pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
97 self.edges.get(word)
98 }
99
100 fn find_sccs(&self) -> Vec<HashSet<String>> {
106 let mut index_counter = 0;
107 let mut stack: Vec<String> = Vec::new();
108 let mut on_stack: HashSet<String> = HashSet::new();
109 let mut indices: HashMap<String, usize> = HashMap::new();
110 let mut lowlinks: HashMap<String, usize> = HashMap::new();
111 let mut sccs: Vec<HashSet<String>> = Vec::new();
112
113 for word in &self.words {
114 if !indices.contains_key(word) {
115 self.tarjan_visit(
116 word,
117 &mut index_counter,
118 &mut stack,
119 &mut on_stack,
120 &mut indices,
121 &mut lowlinks,
122 &mut sccs,
123 );
124 }
125 }
126
127 sccs.into_iter()
129 .filter(|scc| {
130 if scc.len() > 1 {
131 true
133 } else if scc.len() == 1 {
134 let word = scc.iter().next().unwrap();
136 self.edges
137 .get(word)
138 .map(|callees| callees.contains(word))
139 .unwrap_or(false)
140 } else {
141 false
142 }
143 })
144 .collect()
145 }
146
147 #[allow(clippy::too_many_arguments)]
149 fn tarjan_visit(
150 &self,
151 word: &str,
152 index_counter: &mut usize,
153 stack: &mut Vec<String>,
154 on_stack: &mut HashSet<String>,
155 indices: &mut HashMap<String, usize>,
156 lowlinks: &mut HashMap<String, usize>,
157 sccs: &mut Vec<HashSet<String>>,
158 ) {
159 let index = *index_counter;
160 *index_counter += 1;
161 indices.insert(word.to_string(), index);
162 lowlinks.insert(word.to_string(), index);
163 stack.push(word.to_string());
164 on_stack.insert(word.to_string());
165
166 if let Some(callees) = self.edges.get(word) {
168 for callee in callees {
169 if !self.words.contains(callee) {
170 continue;
172 }
173 if !indices.contains_key(callee) {
174 self.tarjan_visit(
176 callee,
177 index_counter,
178 stack,
179 on_stack,
180 indices,
181 lowlinks,
182 sccs,
183 );
184 let callee_lowlink = *lowlinks.get(callee).unwrap();
185 let word_lowlink = lowlinks.get_mut(word).unwrap();
186 *word_lowlink = (*word_lowlink).min(callee_lowlink);
187 } else if on_stack.contains(callee) {
188 let callee_index = *indices.get(callee).unwrap();
190 let word_lowlink = lowlinks.get_mut(word).unwrap();
191 *word_lowlink = (*word_lowlink).min(callee_index);
192 }
193 }
194 }
195
196 if lowlinks.get(word) == indices.get(word) {
198 let mut scc = HashSet::new();
199 loop {
200 let w = stack.pop().unwrap();
201 on_stack.remove(&w);
202 scc.insert(w.clone());
203 if w == word {
204 break;
205 }
206 }
207 sccs.push(scc);
208 }
209 }
210}
211
212fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
216 let mut calls = HashSet::new();
217
218 for stmt in statements {
219 extract_calls_from_statement(stmt, known_words, &mut calls);
220 }
221
222 calls
223}
224
225fn extract_calls_from_statement(
227 stmt: &Statement,
228 known_words: &HashSet<String>,
229 calls: &mut HashSet<String>,
230) {
231 match stmt {
232 Statement::WordCall { name, .. } => {
233 if known_words.contains(name) {
235 calls.insert(name.clone());
236 }
237 }
238 Statement::If {
239 then_branch,
240 else_branch,
241 } => {
242 for s in then_branch {
243 extract_calls_from_statement(s, known_words, calls);
244 }
245 if let Some(else_stmts) = else_branch {
246 for s in else_stmts {
247 extract_calls_from_statement(s, known_words, calls);
248 }
249 }
250 }
251 Statement::Quotation { body, .. } => {
252 for s in body {
253 extract_calls_from_statement(s, known_words, calls);
254 }
255 }
256 Statement::Match { arms } => {
257 for arm in arms {
258 for s in &arm.body {
259 extract_calls_from_statement(s, known_words, calls);
260 }
261 }
262 }
263 Statement::IntLiteral(_)
265 | Statement::FloatLiteral(_)
266 | Statement::BoolLiteral(_)
267 | Statement::StringLiteral(_)
268 | Statement::Symbol(_) => {}
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::WordDef;

    /// A call to `name` with no source span attached.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// A word definition with the given body and every optional field empty.
    fn word_with_body(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// A word whose body is just a flat sequence of calls.
    fn make_word(name: &str, calls: Vec<&str>) -> WordDef {
        word_with_body(name, calls.into_iter().map(call).collect())
    }

    /// A body made of a single `if` with an empty then-branch and an
    /// else-branch that calls `callee`.
    fn else_calls(callee: &str) -> Vec<Statement> {
        vec![Statement::If {
            then_branch: vec![],
            else_branch: Some(vec![call(callee)]),
        }]
    }

    /// Builds the call graph of a program consisting only of `words`.
    fn graph_of(words: Vec<WordDef>) -> CallGraph {
        CallGraph::build(&Program {
            includes: vec![],
            unions: vec![],
            words,
        })
    }

    #[test]
    fn test_no_recursion() {
        let graph = graph_of(vec![
            make_word("foo", vec!["bar"]),
            make_word("bar", vec![]),
            make_word("baz", vec!["foo"]),
        ]);

        for &name in &["foo", "bar", "baz"] {
            assert!(!graph.is_recursive(name));
        }
        assert!(graph.recursive_cycles().is_empty());
    }

    #[test]
    fn test_direct_recursion() {
        let graph = graph_of(vec![
            make_word("countdown", vec!["countdown"]),
            make_word("helper", vec![]),
        ]);

        assert!(graph.is_recursive("countdown"));
        assert!(!graph.is_recursive("helper"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    #[test]
    fn test_mutual_recursion_pair() {
        let graph = graph_of(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
        ]);

        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.are_mutually_recursive("ping", "pong"));

        let cycles = graph.recursive_cycles();
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].len(), 2);
    }

    #[test]
    fn test_mutual_recursion_triple() {
        let graph = graph_of(vec![
            make_word("a", vec!["b"]),
            make_word("b", vec!["c"]),
            make_word("c", vec!["a"]),
        ]);

        for &name in &["a", "b", "c"] {
            assert!(graph.is_recursive(name));
        }
        for &(first, second) in &[("a", "b"), ("b", "c"), ("a", "c")] {
            assert!(graph.are_mutually_recursive(first, second));
        }

        let cycles = graph.recursive_cycles();
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].len(), 3);
    }

    #[test]
    fn test_multiple_independent_cycles() {
        let graph = graph_of(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            make_word("even", vec!["odd"]),
            make_word("odd", vec!["even"]),
            make_word("main", vec!["ping", "even"]),
        ]);

        for &name in &["ping", "pong", "even", "odd"] {
            assert!(graph.is_recursive(name));
        }
        assert!(!graph.is_recursive("main"));

        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert!(graph.are_mutually_recursive("even", "odd"));
        assert!(!graph.are_mutually_recursive("ping", "even"));

        assert_eq!(graph.recursive_cycles().len(), 2);
    }

    #[test]
    fn test_calls_to_unknown_words() {
        // None of these callees is defined in the program, so foo has no edges.
        let graph = graph_of(vec![make_word(
            "foo",
            vec!["dup", "drop", "unknown_builtin"],
        )]);

        assert!(!graph.is_recursive("foo"));
        assert!(graph.callees("foo").unwrap().is_empty());
    }

    #[test]
    fn test_cycle_with_builtins_interspersed() {
        let graph = graph_of(vec![
            make_word("foo", vec!["dup", "drop", "bar"]),
            make_word("bar", vec!["swap", "foo"]),
        ]);

        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));

        let foo_calls = graph.callees("foo").unwrap();
        assert!(foo_calls.contains("bar"));
        for &builtin in &["dup", "drop"] {
            assert!(!foo_calls.contains(builtin));
        }
    }

    #[test]
    fn test_cycle_through_quotation() {
        // foo's only reference to bar sits inside a quotation body.
        let foo = word_with_body(
            "foo",
            vec![
                Statement::Quotation {
                    id: 0,
                    body: vec![call("bar")],
                    span: None,
                },
                call("call"),
            ],
        );
        let graph = graph_of(vec![foo, make_word("bar", vec!["foo"])]);

        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));
    }

    #[test]
    fn test_cycle_through_if_branch() {
        let graph = graph_of(vec![
            word_with_body("even", else_calls("odd")),
            word_with_body("odd", else_calls("even")),
        ]);

        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(graph.are_mutually_recursive("even", "odd"));
    }
}