1use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
40#[derive(Debug, Clone)]
42pub struct CallGraph {
43 edges: HashMap<String, HashSet<String>>,
45 words: HashSet<String>,
47 recursive_sccs: Vec<HashSet<String>>,
50}
51
52impl CallGraph {
53 pub fn build(program: &Program) -> Self {
58 let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59 let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61 for word in &program.words {
62 let callees = extract_calls(&word.body, &words);
63 edges.insert(word.name.clone(), callees);
64 }
65
66 let mut graph = CallGraph {
67 edges,
68 words,
69 recursive_sccs: Vec::new(),
70 };
71
72 graph.recursive_sccs = graph.find_sccs();
74
75 graph
76 }
77
78 pub fn is_recursive(&self, word: &str) -> bool {
80 self.recursive_sccs.iter().any(|scc| scc.contains(word))
81 }
82
83 pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85 self.recursive_sccs
86 .iter()
87 .any(|scc| scc.contains(word1) && scc.contains(word2))
88 }
89
90 pub fn recursive_cycles(&self) -> &[HashSet<String>] {
92 &self.recursive_sccs
93 }
94
95 pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
97 self.edges.get(word)
98 }
99
100 fn find_sccs(&self) -> Vec<HashSet<String>> {
106 let mut index_counter = 0;
107 let mut stack: Vec<String> = Vec::new();
108 let mut on_stack: HashSet<String> = HashSet::new();
109 let mut indices: HashMap<String, usize> = HashMap::new();
110 let mut lowlinks: HashMap<String, usize> = HashMap::new();
111 let mut sccs: Vec<HashSet<String>> = Vec::new();
112
113 for word in &self.words {
114 if !indices.contains_key(word) {
115 self.tarjan_visit(
116 word,
117 &mut index_counter,
118 &mut stack,
119 &mut on_stack,
120 &mut indices,
121 &mut lowlinks,
122 &mut sccs,
123 );
124 }
125 }
126
127 sccs.into_iter()
129 .filter(|scc| {
130 if scc.len() > 1 {
131 true
133 } else if scc.len() == 1 {
134 let word = scc.iter().next().unwrap();
136 self.edges
137 .get(word)
138 .map(|callees| callees.contains(word))
139 .unwrap_or(false)
140 } else {
141 false
142 }
143 })
144 .collect()
145 }
146
147 #[allow(clippy::too_many_arguments)]
149 fn tarjan_visit(
150 &self,
151 word: &str,
152 index_counter: &mut usize,
153 stack: &mut Vec<String>,
154 on_stack: &mut HashSet<String>,
155 indices: &mut HashMap<String, usize>,
156 lowlinks: &mut HashMap<String, usize>,
157 sccs: &mut Vec<HashSet<String>>,
158 ) {
159 let index = *index_counter;
160 *index_counter += 1;
161 indices.insert(word.to_string(), index);
162 lowlinks.insert(word.to_string(), index);
163 stack.push(word.to_string());
164 on_stack.insert(word.to_string());
165
166 if let Some(callees) = self.edges.get(word) {
168 for callee in callees {
169 if !self.words.contains(callee) {
170 continue;
172 }
173 if !indices.contains_key(callee) {
174 self.tarjan_visit(
176 callee,
177 index_counter,
178 stack,
179 on_stack,
180 indices,
181 lowlinks,
182 sccs,
183 );
184 let callee_lowlink = *lowlinks.get(callee).unwrap();
185 let word_lowlink = lowlinks.get_mut(word).unwrap();
186 *word_lowlink = (*word_lowlink).min(callee_lowlink);
187 } else if on_stack.contains(callee) {
188 let callee_index = *indices.get(callee).unwrap();
190 let word_lowlink = lowlinks.get_mut(word).unwrap();
191 *word_lowlink = (*word_lowlink).min(callee_index);
192 }
193 }
194 }
195
196 if lowlinks.get(word) == indices.get(word) {
198 let mut scc = HashSet::new();
199 loop {
200 let w = stack.pop().unwrap();
201 on_stack.remove(&w);
202 scc.insert(w.clone());
203 if w == word {
204 break;
205 }
206 }
207 sccs.push(scc);
208 }
209 }
210}
211
212fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
216 let mut calls = HashSet::new();
217
218 for stmt in statements {
219 extract_calls_from_statement(stmt, known_words, &mut calls);
220 }
221
222 calls
223}
224
225fn extract_calls_from_statement(
227 stmt: &Statement,
228 known_words: &HashSet<String>,
229 calls: &mut HashSet<String>,
230) {
231 match stmt {
232 Statement::WordCall { name, .. } => {
233 if known_words.contains(name) {
235 calls.insert(name.clone());
236 }
237 }
238 Statement::If {
239 then_branch,
240 else_branch,
241 span: _,
242 } => {
243 for s in then_branch {
244 extract_calls_from_statement(s, known_words, calls);
245 }
246 if let Some(else_stmts) = else_branch {
247 for s in else_stmts {
248 extract_calls_from_statement(s, known_words, calls);
249 }
250 }
251 }
252 Statement::Quotation { body, .. } => {
253 for s in body {
254 extract_calls_from_statement(s, known_words, calls);
255 }
256 }
257 Statement::Match { arms, span: _ } => {
258 for arm in arms {
259 for s in &arm.body {
260 extract_calls_from_statement(s, known_words, calls);
261 }
262 }
263 }
264 Statement::IntLiteral(_)
266 | Statement::FloatLiteral(_)
267 | Statement::BoolLiteral(_)
268 | Statement::StringLiteral(_)
269 | Statement::Symbol(_) => {}
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::WordDef;

    /// A call to `name` with no source span.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// A word definition (no effect, source, or allowed lints) whose body is exactly `body`.
    fn word_with_body(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// A word definition whose body is a flat sequence of calls.
    fn make_word(name: &str, calls: Vec<&str>) -> WordDef {
        word_with_body(name, calls.into_iter().map(call).collect())
    }

    /// A program made only of `words`, with no includes or unions.
    fn program(words: Vec<WordDef>) -> Program {
        Program {
            includes: vec![],
            unions: vec![],
            words,
        }
    }

    /// An `if` with an empty then-branch and an else-branch that calls `target`.
    fn else_calls(target: &str) -> Statement {
        Statement::If {
            then_branch: vec![],
            else_branch: Some(vec![call(target)]),
            span: None,
        }
    }

    #[test]
    fn test_no_recursion() {
        let graph = CallGraph::build(&program(vec![
            make_word("foo", vec!["bar"]),
            make_word("bar", vec![]),
            make_word("baz", vec!["foo"]),
        ]));

        for name in ["foo", "bar", "baz"] {
            assert!(!graph.is_recursive(name), "{} should not be recursive", name);
        }
        assert!(graph.recursive_cycles().is_empty());
    }

    #[test]
    fn test_direct_recursion() {
        let graph = CallGraph::build(&program(vec![
            make_word("countdown", vec!["countdown"]),
            make_word("helper", vec![]),
        ]));

        assert!(graph.is_recursive("countdown"));
        assert!(!graph.is_recursive("helper"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    #[test]
    fn test_mutual_recursion_pair() {
        let graph = CallGraph::build(&program(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
        ]));

        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.are_mutually_recursive("ping", "pong"));

        let cycles = graph.recursive_cycles();
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].len(), 2);
    }

    #[test]
    fn test_mutual_recursion_triple() {
        let graph = CallGraph::build(&program(vec![
            make_word("a", vec!["b"]),
            make_word("b", vec!["c"]),
            make_word("c", vec!["a"]),
        ]));

        for name in ["a", "b", "c"] {
            assert!(graph.is_recursive(name), "{} should be recursive", name);
        }
        for (x, y) in [("a", "b"), ("b", "c"), ("a", "c")] {
            assert!(graph.are_mutually_recursive(x, y));
        }

        let cycles = graph.recursive_cycles();
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].len(), 3);
    }

    #[test]
    fn test_multiple_independent_cycles() {
        // Two disjoint two-word cycles plus a non-recursive word calling into both.
        let graph = CallGraph::build(&program(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            make_word("even", vec!["odd"]),
            make_word("odd", vec!["even"]),
            make_word("main", vec!["ping", "even"]),
        ]));

        for name in ["ping", "pong", "even", "odd"] {
            assert!(graph.is_recursive(name), "{} should be recursive", name);
        }
        assert!(!graph.is_recursive("main"));

        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert!(graph.are_mutually_recursive("even", "odd"));
        assert!(!graph.are_mutually_recursive("ping", "even"));

        assert_eq!(graph.recursive_cycles().len(), 2);
    }

    #[test]
    fn test_calls_to_unknown_words() {
        // None of the called names is defined in the program, so no edges are kept.
        let graph = CallGraph::build(&program(vec![make_word(
            "foo",
            vec!["dup", "drop", "unknown_builtin"],
        )]));

        assert!(!graph.is_recursive("foo"));
        assert!(graph.callees("foo").unwrap().is_empty());
    }

    #[test]
    fn test_cycle_with_builtins_interspersed() {
        let graph = CallGraph::build(&program(vec![
            make_word("foo", vec!["dup", "drop", "bar"]),
            make_word("bar", vec!["swap", "foo"]),
        ]));

        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));

        let foo_callees = graph.callees("foo").unwrap();
        assert!(foo_callees.contains("bar"));
        for builtin in ["dup", "drop"] {
            assert!(!foo_callees.contains(builtin));
        }
    }

    #[test]
    fn test_cycle_through_quotation() {
        // `foo` reaches `bar` only from inside a quotation body.
        let foo = word_with_body(
            "foo",
            vec![
                Statement::Quotation {
                    id: 0,
                    body: vec![call("bar")],
                    span: None,
                },
                call("call"),
            ],
        );
        let graph = CallGraph::build(&program(vec![foo, make_word("bar", vec!["foo"])]));

        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));
    }

    #[test]
    fn test_cycle_through_if_branch() {
        // Each word reaches the other only from its else-branch.
        let graph = CallGraph::build(&program(vec![
            word_with_body("even", vec![else_calls("odd")]),
            word_with_body("odd", vec![else_calls("even")]),
        ]));

        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(graph.are_mutually_recursive("even", "odd"));
    }
}