Skip to main content

seqc/
call_graph.rs

1//! Call graph analysis for detecting mutual recursion
2//!
3//! This module builds a call graph from a Seq program and detects
4//! strongly connected components (SCCs) to identify mutual recursion cycles.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! let call_graph = CallGraph::build(&program);
10//! let cycles = call_graph.recursive_cycles();
11//! ```
12//!
13//! # Primary Use Cases
14//!
15//! 1. **Type checker divergence detection**: The type checker uses the call graph
16//!    to identify mutually recursive tail calls, enabling correct type inference
17//!    for patterns like even/odd that would otherwise require branch unification.
18//!
19//! 2. **Future optimizations**: The call graph infrastructure can support dead code
20//!    detection, inlining decisions, and diagnostic tools.
21//!
22//! # Implementation Details
23//!
24//! - **Algorithm**: Tarjan's SCC algorithm, O(V + E) time complexity
25//! - **Builtins**: Calls to builtins/external words are excluded from the graph
26//!   (they don't affect recursion detection since they always return)
27//! - **Quotations**: Calls within quotations are included in the analysis
28//! - **Match arms**: Calls within match arms are included in the analysis
29//!
30//! # Note on Tail Call Optimization
31//!
32//! The existing codegen already emits `musttail` for all tail calls to user-defined
33//! words (see `codegen/statements.rs`). This means mutual TCO works automatically
34//! without needing explicit call graph checks in codegen. The call graph is primarily
35//! used for type checking, not for enabling TCO.
36
37use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// Directed graph recording which words of a program call which other words.
///
/// Nodes are word names; an edge `a -> b` means the body of `a` contains a
/// call to `b`. All three fields are keyed by word name.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// Adjacency map: caller name -> names of the words that caller invokes.
    edges: HashMap<String, HashSet<String>>,
    /// Every word name defined in the program, i.e. the node set of the graph.
    words: HashSet<String>,
    /// Strongly connected components that contain recursion: components with
    /// more than one member (mutual recursion) and single-member components
    /// whose word calls itself (direct recursion).
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53    /// Build a call graph from a program.
54    ///
55    /// This extracts all word-to-word call relationships, including calls
56    /// within quotations, if branches, and match arms.
57    pub fn build(program: &Program) -> Self {
58        let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59        let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61        for word in &program.words {
62            let callees = extract_calls(&word.body, &words);
63            edges.insert(word.name.clone(), callees);
64        }
65
66        let mut graph = CallGraph {
67            edges,
68            words,
69            recursive_sccs: Vec::new(),
70        };
71
72        // Compute SCCs and identify recursive cycles
73        graph.recursive_sccs = graph.find_sccs();
74
75        graph
76    }
77
78    /// Check if a word is part of any recursive cycle (direct or mutual).
79    pub fn is_recursive(&self, word: &str) -> bool {
80        self.recursive_sccs.iter().any(|scc| scc.contains(word))
81    }
82
83    /// Check if two words are in the same recursive cycle (mutually recursive).
84    pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85        self.recursive_sccs
86            .iter()
87            .any(|scc| scc.contains(word1) && scc.contains(word2))
88    }
89
90    /// Get all recursive cycles (SCCs with recursion).
91    pub fn recursive_cycles(&self) -> &[HashSet<String>] {
92        &self.recursive_sccs
93    }
94
95    /// Get the words that a given word calls.
96    pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
97        self.edges.get(word)
98    }
99
100    /// Find strongly connected components using Tarjan's algorithm.
101    ///
102    /// Returns only SCCs that represent recursion:
103    /// - Multi-word SCCs (mutual recursion)
104    /// - Single-word SCCs where the word calls itself (direct recursion)
105    fn find_sccs(&self) -> Vec<HashSet<String>> {
106        let mut index_counter = 0;
107        let mut stack: Vec<String> = Vec::new();
108        let mut on_stack: HashSet<String> = HashSet::new();
109        let mut indices: HashMap<String, usize> = HashMap::new();
110        let mut lowlinks: HashMap<String, usize> = HashMap::new();
111        let mut sccs: Vec<HashSet<String>> = Vec::new();
112
113        for word in &self.words {
114            if !indices.contains_key(word) {
115                self.tarjan_visit(
116                    word,
117                    &mut index_counter,
118                    &mut stack,
119                    &mut on_stack,
120                    &mut indices,
121                    &mut lowlinks,
122                    &mut sccs,
123                );
124            }
125        }
126
127        // Filter to only recursive SCCs
128        sccs.into_iter()
129            .filter(|scc| {
130                if scc.len() > 1 {
131                    // Multi-word SCC = mutual recursion
132                    true
133                } else if scc.len() == 1 {
134                    // Single-word SCC: check if it calls itself
135                    let word = scc.iter().next().unwrap();
136                    self.edges
137                        .get(word)
138                        .map(|callees| callees.contains(word))
139                        .unwrap_or(false)
140                } else {
141                    false
142                }
143            })
144            .collect()
145    }
146
147    /// Tarjan's algorithm recursive visit.
148    #[allow(clippy::too_many_arguments)]
149    fn tarjan_visit(
150        &self,
151        word: &str,
152        index_counter: &mut usize,
153        stack: &mut Vec<String>,
154        on_stack: &mut HashSet<String>,
155        indices: &mut HashMap<String, usize>,
156        lowlinks: &mut HashMap<String, usize>,
157        sccs: &mut Vec<HashSet<String>>,
158    ) {
159        let index = *index_counter;
160        *index_counter += 1;
161        indices.insert(word.to_string(), index);
162        lowlinks.insert(word.to_string(), index);
163        stack.push(word.to_string());
164        on_stack.insert(word.to_string());
165
166        // Visit all callees
167        if let Some(callees) = self.edges.get(word) {
168            for callee in callees {
169                if !self.words.contains(callee) {
170                    // External word (builtin), skip
171                    continue;
172                }
173                if !indices.contains_key(callee) {
174                    // Not yet visited
175                    self.tarjan_visit(
176                        callee,
177                        index_counter,
178                        stack,
179                        on_stack,
180                        indices,
181                        lowlinks,
182                        sccs,
183                    );
184                    let callee_lowlink = *lowlinks.get(callee).unwrap();
185                    let word_lowlink = lowlinks.get_mut(word).unwrap();
186                    *word_lowlink = (*word_lowlink).min(callee_lowlink);
187                } else if on_stack.contains(callee) {
188                    // Callee is on stack, part of current SCC
189                    let callee_index = *indices.get(callee).unwrap();
190                    let word_lowlink = lowlinks.get_mut(word).unwrap();
191                    *word_lowlink = (*word_lowlink).min(callee_index);
192                }
193            }
194        }
195
196        // If word is a root node, pop the SCC
197        if lowlinks.get(word) == indices.get(word) {
198            let mut scc = HashSet::new();
199            loop {
200                let w = stack.pop().unwrap();
201                on_stack.remove(&w);
202                scc.insert(w.clone());
203                if w == word {
204                    break;
205                }
206            }
207            sccs.push(scc);
208        }
209    }
210}
211
212/// Extract all word calls from a list of statements.
213///
214/// This recursively descends into quotations, if branches, and match arms.
215fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
216    let mut calls = HashSet::new();
217
218    for stmt in statements {
219        extract_calls_from_statement(stmt, known_words, &mut calls);
220    }
221
222    calls
223}
224
225/// Extract word calls from a single statement.
226fn extract_calls_from_statement(
227    stmt: &Statement,
228    known_words: &HashSet<String>,
229    calls: &mut HashSet<String>,
230) {
231    match stmt {
232        Statement::WordCall { name, .. } => {
233            // Only track calls to user-defined words
234            if known_words.contains(name) {
235                calls.insert(name.clone());
236            }
237        }
238        Statement::If {
239            then_branch,
240            else_branch,
241        } => {
242            for s in then_branch {
243                extract_calls_from_statement(s, known_words, calls);
244            }
245            if let Some(else_stmts) = else_branch {
246                for s in else_stmts {
247                    extract_calls_from_statement(s, known_words, calls);
248                }
249            }
250        }
251        Statement::Quotation { body, .. } => {
252            for s in body {
253                extract_calls_from_statement(s, known_words, calls);
254            }
255        }
256        Statement::Match { arms } => {
257            for arm in arms {
258                for s in &arm.body {
259                    extract_calls_from_statement(s, known_words, calls);
260                }
261            }
262        }
263        // Literals don't contain calls
264        Statement::IntLiteral(_)
265        | Statement::FloatLiteral(_)
266        | Statement::BoolLiteral(_)
267        | Statement::StringLiteral(_)
268        | Statement::Symbol(_) => {}
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::WordDef;

    /// A plain call to `name` with no source span.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// A word definition with the given body and every optional field empty.
    fn word_with_body(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// A word whose body is just a sequence of calls to `calls`.
    fn make_word(name: &str, calls: Vec<&str>) -> WordDef {
        word_with_body(name, calls.into_iter().map(call).collect())
    }

    /// Build the call graph of a program consisting only of `words`.
    fn graph_of(words: Vec<WordDef>) -> CallGraph {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words,
        };
        CallGraph::build(&program)
    }

    /// An `if` with an empty then-branch whose else-branch calls `target`.
    fn if_else_calling(target: &str) -> Statement {
        Statement::If {
            then_branch: vec![],
            else_branch: Some(vec![call(target)]),
        }
    }

    #[test]
    fn test_no_recursion() {
        let graph = graph_of(vec![
            make_word("foo", vec!["bar"]),
            make_word("bar", vec![]),
            make_word("baz", vec!["foo"]),
        ]);

        assert!(!graph.is_recursive("foo"));
        assert!(!graph.is_recursive("bar"));
        assert!(!graph.is_recursive("baz"));
        assert!(graph.recursive_cycles().is_empty());
    }

    #[test]
    fn test_direct_recursion() {
        let graph = graph_of(vec![
            make_word("countdown", vec!["countdown"]),
            make_word("helper", vec![]),
        ]);

        assert!(graph.is_recursive("countdown"));
        assert!(!graph.is_recursive("helper"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    #[test]
    fn test_mutual_recursion_pair() {
        let graph = graph_of(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
        ]);

        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 2);
    }

    #[test]
    fn test_mutual_recursion_triple() {
        let graph = graph_of(vec![
            make_word("a", vec!["b"]),
            make_word("b", vec!["c"]),
            make_word("c", vec!["a"]),
        ]);

        assert!(graph.is_recursive("a"));
        assert!(graph.is_recursive("b"));
        assert!(graph.is_recursive("c"));
        assert!(graph.are_mutually_recursive("a", "b"));
        assert!(graph.are_mutually_recursive("b", "c"));
        assert!(graph.are_mutually_recursive("a", "c"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 3);
    }

    #[test]
    fn test_multiple_independent_cycles() {
        let graph = graph_of(vec![
            // Cycle 1: ping <-> pong
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            // Cycle 2: even <-> odd
            make_word("even", vec!["odd"]),
            make_word("odd", vec!["even"]),
            // Reaches both cycles but is not part of either
            make_word("main", vec!["ping", "even"]),
        ]);

        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(!graph.is_recursive("main"));

        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert!(graph.are_mutually_recursive("even", "odd"));
        assert!(!graph.are_mutually_recursive("ping", "even"));

        assert_eq!(graph.recursive_cycles().len(), 2);
    }

    #[test]
    fn test_calls_to_unknown_words() {
        // Names that are not defined in the program must not become edges.
        let graph = graph_of(vec![make_word(
            "foo",
            vec!["dup", "drop", "unknown_builtin"],
        )]);

        assert!(!graph.is_recursive("foo"));
        // Callees should only include known words
        assert!(graph.callees("foo").unwrap().is_empty());
    }

    #[test]
    fn test_cycle_with_builtins_interspersed() {
        // A cycle must still be found when undefined names sit between the
        // calls that form it, e.g. : foo dup drop bar ;  : bar swap foo ;
        let graph = graph_of(vec![
            make_word("foo", vec!["dup", "drop", "bar"]),
            make_word("bar", vec!["swap", "foo"]),
        ]);

        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));

        // Only the defined word shows up among foo's callees.
        let foo_callees = graph.callees("foo").unwrap();
        assert!(foo_callees.contains("bar"));
        assert!(!foo_callees.contains("dup"));
        assert!(!foo_callees.contains("drop"));
    }

    #[test]
    fn test_cycle_through_quotation() {
        // A call inside a quotation counts as an edge,
        // e.g. : foo [ bar ] call ;  : bar foo ;
        let quoted_bar = Statement::Quotation {
            id: 0,
            body: vec![call("bar")],
            span: None,
        };

        let graph = graph_of(vec![
            word_with_body("foo", vec![quoted_bar, call("call")]),
            make_word("bar", vec!["foo"]),
        ]);

        // foo calls bar (inside quotation), bar calls foo
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));
    }

    #[test]
    fn test_cycle_through_if_branch() {
        // A call inside an else branch counts as an edge.
        let graph = graph_of(vec![
            word_with_body("even", vec![if_else_calling("odd")]),
            word_with_body("odd", vec![if_else_calling("even")]),
        ]);

        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(graph.are_mutually_recursive("even", "odd"));
    }
}