Skip to main content

seqc/
call_graph.rs

//! Call graph analysis for detecting mutual recursion
//!
//! This module builds a call graph from a Seq program and detects
//! strongly connected components (SCCs) to identify mutual recursion cycles.
//!
//! # Usage
//!
//! ```ignore
//! let call_graph = CallGraph::build(&program);
//! let cycles = call_graph.recursive_cycles();
//! ```
//!
//! # Primary Use Cases
//!
//! 1. **Type checker divergence detection**: The type checker uses the call graph
//!    to identify mutually recursive tail calls, enabling correct type inference
//!    for patterns like even/odd that would otherwise require branch unification.
//!
//! 2. **Future optimizations**: The call graph infrastructure can support dead code
//!    detection, inlining decisions, and diagnostic tools.
//!
//! # Implementation Details
//!
//! - **Algorithm**: Tarjan's SCC algorithm, O(V + E) time complexity
//! - **Builtins**: Calls to builtins/external words are excluded from the graph
//!   (they don't affect recursion detection since they always return)
//! - **Quotations**: Calls within quotations are included in the analysis
//! - **Match arms**: Calls within match arms are included in the analysis
//!
//! # Note on Tail Call Optimization
//!
//! The existing codegen already emits `musttail` for all tail calls to user-defined
//! words (see `codegen/statements.rs`). This means mutual TCO works automatically
//! without needing explicit call graph checks in codegen. The call graph is primarily
//! used for type checking, not for enabling TCO.
36
37use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// Directed graph of caller -> callee relationships between user-defined words.
///
/// Instances are produced by [`CallGraph::build`], which also computes the
/// recursive components once and stores them alongside the edges.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// For each defined word, the set of user-defined words its body calls.
    edges: HashMap<String, HashSet<String>>,
    /// Names of every word defined in the program.
    words: HashSet<String>,
    /// Cached recursive components: SCCs with several members (mutual
    /// recursion) plus single-word SCCs whose word calls itself (direct
    /// recursion).
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53    /// Build a call graph from a program.
54    ///
55    /// This extracts all word-to-word call relationships, including calls
56    /// within quotations, if branches, and match arms.
57    pub fn build(program: &Program) -> Self {
58        let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59        let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61        for word in &program.words {
62            let callees = extract_calls(&word.body, &words);
63            edges.insert(word.name.clone(), callees);
64        }
65
66        let mut graph = CallGraph {
67            edges,
68            words,
69            recursive_sccs: Vec::new(),
70        };
71
72        // Compute SCCs and identify recursive cycles
73        graph.recursive_sccs = graph.find_sccs();
74
75        graph
76    }
77
78    /// Check if a word is part of any recursive cycle (direct or mutual).
79    pub fn is_recursive(&self, word: &str) -> bool {
80        self.recursive_sccs.iter().any(|scc| scc.contains(word))
81    }
82
83    /// Check if two words are in the same recursive cycle (mutually recursive).
84    pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85        self.recursive_sccs
86            .iter()
87            .any(|scc| scc.contains(word1) && scc.contains(word2))
88    }
89
90    /// Get all recursive cycles (SCCs with recursion).
91    pub fn recursive_cycles(&self) -> &[HashSet<String>] {
92        &self.recursive_sccs
93    }
94
95    /// Get the words that a given word calls.
96    pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
97        self.edges.get(word)
98    }
99
100    /// Find strongly connected components using Tarjan's algorithm.
101    ///
102    /// Returns only SCCs that represent recursion:
103    /// - Multi-word SCCs (mutual recursion)
104    /// - Single-word SCCs where the word calls itself (direct recursion)
105    fn find_sccs(&self) -> Vec<HashSet<String>> {
106        let mut index_counter = 0;
107        let mut stack: Vec<String> = Vec::new();
108        let mut on_stack: HashSet<String> = HashSet::new();
109        let mut indices: HashMap<String, usize> = HashMap::new();
110        let mut lowlinks: HashMap<String, usize> = HashMap::new();
111        let mut sccs: Vec<HashSet<String>> = Vec::new();
112
113        for word in &self.words {
114            if !indices.contains_key(word) {
115                self.tarjan_visit(
116                    word,
117                    &mut index_counter,
118                    &mut stack,
119                    &mut on_stack,
120                    &mut indices,
121                    &mut lowlinks,
122                    &mut sccs,
123                );
124            }
125        }
126
127        // Filter to only recursive SCCs
128        sccs.into_iter()
129            .filter(|scc| {
130                if scc.len() > 1 {
131                    // Multi-word SCC = mutual recursion
132                    true
133                } else if scc.len() == 1 {
134                    // Single-word SCC: check if it calls itself
135                    let word = scc.iter().next().unwrap();
136                    self.edges
137                        .get(word)
138                        .map(|callees| callees.contains(word))
139                        .unwrap_or(false)
140                } else {
141                    false
142                }
143            })
144            .collect()
145    }
146
147    /// Tarjan's algorithm recursive visit.
148    #[allow(clippy::too_many_arguments)]
149    fn tarjan_visit(
150        &self,
151        word: &str,
152        index_counter: &mut usize,
153        stack: &mut Vec<String>,
154        on_stack: &mut HashSet<String>,
155        indices: &mut HashMap<String, usize>,
156        lowlinks: &mut HashMap<String, usize>,
157        sccs: &mut Vec<HashSet<String>>,
158    ) {
159        let index = *index_counter;
160        *index_counter += 1;
161        indices.insert(word.to_string(), index);
162        lowlinks.insert(word.to_string(), index);
163        stack.push(word.to_string());
164        on_stack.insert(word.to_string());
165
166        // Visit all callees
167        if let Some(callees) = self.edges.get(word) {
168            for callee in callees {
169                if !self.words.contains(callee) {
170                    // External word (builtin), skip
171                    continue;
172                }
173                if !indices.contains_key(callee) {
174                    // Not yet visited
175                    self.tarjan_visit(
176                        callee,
177                        index_counter,
178                        stack,
179                        on_stack,
180                        indices,
181                        lowlinks,
182                        sccs,
183                    );
184                    let callee_lowlink = *lowlinks.get(callee).unwrap();
185                    let word_lowlink = lowlinks.get_mut(word).unwrap();
186                    *word_lowlink = (*word_lowlink).min(callee_lowlink);
187                } else if on_stack.contains(callee) {
188                    // Callee is on stack, part of current SCC
189                    let callee_index = *indices.get(callee).unwrap();
190                    let word_lowlink = lowlinks.get_mut(word).unwrap();
191                    *word_lowlink = (*word_lowlink).min(callee_index);
192                }
193            }
194        }
195
196        // If word is a root node, pop the SCC
197        if lowlinks.get(word) == indices.get(word) {
198            let mut scc = HashSet::new();
199            loop {
200                let w = stack.pop().unwrap();
201                on_stack.remove(&w);
202                scc.insert(w.clone());
203                if w == word {
204                    break;
205                }
206            }
207            sccs.push(scc);
208        }
209    }
210}
211
212/// Extract all word calls from a list of statements.
213///
214/// This recursively descends into quotations, if branches, and match arms.
215fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
216    let mut calls = HashSet::new();
217
218    for stmt in statements {
219        extract_calls_from_statement(stmt, known_words, &mut calls);
220    }
221
222    calls
223}
224
225/// Extract word calls from a single statement.
226fn extract_calls_from_statement(
227    stmt: &Statement,
228    known_words: &HashSet<String>,
229    calls: &mut HashSet<String>,
230) {
231    match stmt {
232        Statement::WordCall { name, .. } => {
233            // Only track calls to user-defined words
234            if known_words.contains(name) {
235                calls.insert(name.clone());
236            }
237        }
238        Statement::If {
239            then_branch,
240            else_branch,
241            span: _,
242        } => {
243            for s in then_branch {
244                extract_calls_from_statement(s, known_words, calls);
245            }
246            if let Some(else_stmts) = else_branch {
247                for s in else_stmts {
248                    extract_calls_from_statement(s, known_words, calls);
249                }
250            }
251        }
252        Statement::Quotation { body, .. } => {
253            for s in body {
254                extract_calls_from_statement(s, known_words, calls);
255            }
256        }
257        Statement::Match { arms, span: _ } => {
258            for arm in arms {
259                for s in &arm.body {
260                    extract_calls_from_statement(s, known_words, calls);
261                }
262            }
263        }
264        // Literals don't contain calls
265        Statement::IntLiteral(_)
266        | Statement::FloatLiteral(_)
267        | Statement::BoolLiteral(_)
268        | Statement::StringLiteral(_)
269        | Statement::Symbol(_) => {}
270    }
271}
272
273#[cfg(test)]
274mod tests;