Skip to main content

seqc/
call_graph.rs

//! Call graph analysis for detecting mutual recursion
//!
//! This module builds a call graph from a Seq program and detects
//! strongly connected components (SCCs) to identify mutual recursion cycles.
//!
//! # Usage
//!
//! ```ignore
//! let call_graph = CallGraph::build(&program);
//! let cycles = call_graph.recursive_cycles();
//! ```
//!
//! # Primary Use Cases
//!
//! 1. **Type checker divergence detection**: The type checker uses the call graph
//!    to identify mutually recursive tail calls, enabling correct type inference
//!    for patterns like even/odd that would otherwise require branch unification.
//!
//! 2. **Future optimizations**: The call graph infrastructure can support dead code
//!    detection, inlining decisions, and diagnostic tools.
//!
//! # Implementation Details
//!
//! - **Algorithm**: Tarjan's SCC algorithm, O(V + E) time complexity
//! - **Builtins**: Calls to builtins/external words are excluded from the graph
//!   (they are not nodes of the graph, so they can never be part of a cycle)
//! - **Quotations**: Calls within quotations are included in the analysis and are
//!   attributed to the word whose body contains the quotation
//! - **Match arms**: Calls within match arms are included in the analysis
//!
//! # Note on Tail Call Optimization
//!
//! The existing codegen already emits `musttail` for all tail calls to user-defined
//! words (see `codegen/statements.rs`). This means mutual TCO works automatically
//! without needing explicit call graph checks in codegen. The call graph is primarily
//! used for type checking, not for enabling TCO.
36
37use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// A call graph representing which words call which other words.
///
/// Nodes are the words defined in the program. An edge `a -> b` means the body
/// of `a` contains a call to the user-defined word `b`. This includes calls
/// nested inside quotations, `if` branches and match arms. Calls to names that
/// are not defined in the program (builtins, external words) are not recorded.
///
/// The recursive strongly connected components are computed once by
/// `CallGraph::build` and cached in `recursive_sccs`.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// Map from word name to the set of words it calls
    edges: HashMap<String, HashSet<String>>,
    /// All word names in the program
    words: HashSet<String>,
    /// Strongly connected components with more than one member (mutual recursion)
    /// or single members that call themselves (direct recursion)
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53    /// Build a call graph from a program.
54    ///
55    /// This extracts all word-to-word call relationships, including calls
56    /// within quotations, if branches, and match arms.
57    pub fn build(program: &Program) -> Self {
58        let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59        let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61        for word in &program.words {
62            let callees = extract_calls(&word.body, &words);
63            edges.insert(word.name.clone(), callees);
64        }
65
66        let mut graph = CallGraph {
67            edges,
68            words,
69            recursive_sccs: Vec::new(),
70        };
71
72        // Compute SCCs and identify recursive cycles
73        graph.recursive_sccs = graph.find_sccs();
74
75        graph
76    }
77
78    /// Check if a word is part of any recursive cycle (direct or mutual).
79    pub fn is_recursive(&self, word: &str) -> bool {
80        self.recursive_sccs.iter().any(|scc| scc.contains(word))
81    }
82
83    /// Check if two words are in the same recursive cycle (mutually recursive).
84    pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85        self.recursive_sccs
86            .iter()
87            .any(|scc| scc.contains(word1) && scc.contains(word2))
88    }
89
90    /// Get the recursive cycle containing a word, if any.
91    pub fn get_cycle(&self, word: &str) -> Option<&HashSet<String>> {
92        self.recursive_sccs.iter().find(|scc| scc.contains(word))
93    }
94
95    /// Get all recursive cycles (SCCs with recursion).
96    pub fn recursive_cycles(&self) -> &[HashSet<String>] {
97        &self.recursive_sccs
98    }
99
100    /// Get the words that a given word calls.
101    pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
102        self.edges.get(word)
103    }
104
105    /// Find strongly connected components using Tarjan's algorithm.
106    ///
107    /// Returns only SCCs that represent recursion:
108    /// - Multi-word SCCs (mutual recursion)
109    /// - Single-word SCCs where the word calls itself (direct recursion)
110    fn find_sccs(&self) -> Vec<HashSet<String>> {
111        let mut index_counter = 0;
112        let mut stack: Vec<String> = Vec::new();
113        let mut on_stack: HashSet<String> = HashSet::new();
114        let mut indices: HashMap<String, usize> = HashMap::new();
115        let mut lowlinks: HashMap<String, usize> = HashMap::new();
116        let mut sccs: Vec<HashSet<String>> = Vec::new();
117
118        for word in &self.words {
119            if !indices.contains_key(word) {
120                self.tarjan_visit(
121                    word,
122                    &mut index_counter,
123                    &mut stack,
124                    &mut on_stack,
125                    &mut indices,
126                    &mut lowlinks,
127                    &mut sccs,
128                );
129            }
130        }
131
132        // Filter to only recursive SCCs
133        sccs.into_iter()
134            .filter(|scc| {
135                if scc.len() > 1 {
136                    // Multi-word SCC = mutual recursion
137                    true
138                } else if scc.len() == 1 {
139                    // Single-word SCC: check if it calls itself
140                    let word = scc.iter().next().unwrap();
141                    self.edges
142                        .get(word)
143                        .map(|callees| callees.contains(word))
144                        .unwrap_or(false)
145                } else {
146                    false
147                }
148            })
149            .collect()
150    }
151
152    /// Tarjan's algorithm recursive visit.
153    #[allow(clippy::too_many_arguments)]
154    fn tarjan_visit(
155        &self,
156        word: &str,
157        index_counter: &mut usize,
158        stack: &mut Vec<String>,
159        on_stack: &mut HashSet<String>,
160        indices: &mut HashMap<String, usize>,
161        lowlinks: &mut HashMap<String, usize>,
162        sccs: &mut Vec<HashSet<String>>,
163    ) {
164        let index = *index_counter;
165        *index_counter += 1;
166        indices.insert(word.to_string(), index);
167        lowlinks.insert(word.to_string(), index);
168        stack.push(word.to_string());
169        on_stack.insert(word.to_string());
170
171        // Visit all callees
172        if let Some(callees) = self.edges.get(word) {
173            for callee in callees {
174                if !self.words.contains(callee) {
175                    // External word (builtin), skip
176                    continue;
177                }
178                if !indices.contains_key(callee) {
179                    // Not yet visited
180                    self.tarjan_visit(
181                        callee,
182                        index_counter,
183                        stack,
184                        on_stack,
185                        indices,
186                        lowlinks,
187                        sccs,
188                    );
189                    let callee_lowlink = *lowlinks.get(callee).unwrap();
190                    let word_lowlink = lowlinks.get_mut(word).unwrap();
191                    *word_lowlink = (*word_lowlink).min(callee_lowlink);
192                } else if on_stack.contains(callee) {
193                    // Callee is on stack, part of current SCC
194                    let callee_index = *indices.get(callee).unwrap();
195                    let word_lowlink = lowlinks.get_mut(word).unwrap();
196                    *word_lowlink = (*word_lowlink).min(callee_index);
197                }
198            }
199        }
200
201        // If word is a root node, pop the SCC
202        if lowlinks.get(word) == indices.get(word) {
203            let mut scc = HashSet::new();
204            loop {
205                let w = stack.pop().unwrap();
206                on_stack.remove(&w);
207                scc.insert(w.clone());
208                if w == word {
209                    break;
210                }
211            }
212            sccs.push(scc);
213        }
214    }
215}
216
217/// Extract all word calls from a list of statements.
218///
219/// This recursively descends into quotations, if branches, and match arms.
220fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
221    let mut calls = HashSet::new();
222
223    for stmt in statements {
224        extract_calls_from_statement(stmt, known_words, &mut calls);
225    }
226
227    calls
228}
229
230/// Extract word calls from a single statement.
231fn extract_calls_from_statement(
232    stmt: &Statement,
233    known_words: &HashSet<String>,
234    calls: &mut HashSet<String>,
235) {
236    match stmt {
237        Statement::WordCall { name, .. } => {
238            // Only track calls to user-defined words
239            if known_words.contains(name) {
240                calls.insert(name.clone());
241            }
242        }
243        Statement::If {
244            then_branch,
245            else_branch,
246        } => {
247            for s in then_branch {
248                extract_calls_from_statement(s, known_words, calls);
249            }
250            if let Some(else_stmts) = else_branch {
251                for s in else_stmts {
252                    extract_calls_from_statement(s, known_words, calls);
253                }
254            }
255        }
256        Statement::Quotation { body, .. } => {
257            for s in body {
258                extract_calls_from_statement(s, known_words, calls);
259            }
260        }
261        Statement::Match { arms } => {
262            for arm in arms {
263                for s in &arm.body {
264                    extract_calls_from_statement(s, known_words, calls);
265                }
266            }
267        }
268        // Literals don't contain calls
269        Statement::IntLiteral(_)
270        | Statement::FloatLiteral(_)
271        | Statement::BoolLiteral(_)
272        | Statement::StringLiteral(_)
273        | Statement::Symbol(_) => {}
274    }
275}
276
277/// Information about tail calls for mutual TCO optimization.
278///
279/// # Current Status
280///
281/// This struct is currently **infrastructure for future use**. The existing codegen
282/// already emits `musttail` for all tail calls to user-defined words, so mutual TCO
283/// works without explicit call graph checks.
284///
285/// Potential future uses:
286/// - Selective TCO (only optimize detected recursive cycles)
287/// - Diagnostic tools (show which words are mutually recursive)
288/// - Dead code detection (unreachable words in non-recursive paths)
289#[derive(Debug, Clone)]
290#[allow(dead_code)] // Infrastructure for future optimizations
291pub struct TailCallInfo {
292    /// Words that are in a recursive cycle and should get mutual TCO
293    pub recursive_words: HashSet<String>,
294}
295
296impl TailCallInfo {
297    /// Build tail call info from a call graph.
298    pub fn from_call_graph(graph: &CallGraph) -> Self {
299        let mut recursive_words = HashSet::new();
300        for scc in graph.recursive_cycles() {
301            recursive_words.extend(scc.iter().cloned());
302        }
303        TailCallInfo { recursive_words }
304    }
305
306    /// Check if a call from `caller` to `callee` is between mutually recursive words.
307    ///
308    /// Returns true if both are in the same recursive cycle.
309    ///
310    /// # Note
311    ///
312    /// Currently unused in codegen since all user-word tail calls get `musttail`.
313    /// Kept as infrastructure for potential future selective optimization.
314    #[allow(dead_code)] // Infrastructure for future optimizations
315    pub fn should_use_musttail(&self, caller: &str, callee: &str) -> bool {
316        self.recursive_words.contains(caller) && self.recursive_words.contains(callee)
317    }
318}
319
// Unit tests for call-graph construction, recursive-SCC detection and
// `TailCallInfo`. Programs are built directly from AST nodes; no parsing involved.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::WordDef;

    // Build a `WordDef` named `name` whose body is a flat sequence of
    // `WordCall`s to `calls`, in order, with no effect, source or lint allowances.
    fn make_word(name: &str, calls: Vec<&str>) -> WordDef {
        let body = calls
            .into_iter()
            .map(|c| Statement::WordCall {
                name: c.to_string(),
                span: None,
            })
            .collect();
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    // An acyclic call chain (baz -> foo -> bar) has no recursive cycles.
    #[test]
    fn test_no_recursion() {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                make_word("foo", vec!["bar"]),
                make_word("bar", vec![]),
                make_word("baz", vec!["foo"]),
            ],
        };

        let graph = CallGraph::build(&program);
        assert!(!graph.is_recursive("foo"));
        assert!(!graph.is_recursive("bar"));
        assert!(!graph.is_recursive("baz"));
        assert!(graph.recursive_cycles().is_empty());
    }

    // A word that calls itself forms a single-member recursive cycle.
    #[test]
    fn test_direct_recursion() {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                make_word("countdown", vec!["countdown"]),
                make_word("helper", vec![]),
            ],
        };

        let graph = CallGraph::build(&program);
        assert!(graph.is_recursive("countdown"));
        assert!(!graph.is_recursive("helper"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    // Two words calling each other form one two-member cycle.
    #[test]
    fn test_mutual_recursion_pair() {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                make_word("ping", vec!["pong"]),
                make_word("pong", vec!["ping"]),
            ],
        };

        let graph = CallGraph::build(&program);
        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 2);
    }

    // A three-word ring (a -> b -> c -> a) is reported as one three-member cycle.
    #[test]
    fn test_mutual_recursion_triple() {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                make_word("a", vec!["b"]),
                make_word("b", vec!["c"]),
                make_word("c", vec!["a"]),
            ],
        };

        let graph = CallGraph::build(&program);
        assert!(graph.is_recursive("a"));
        assert!(graph.is_recursive("b"));
        assert!(graph.is_recursive("c"));
        assert!(graph.are_mutually_recursive("a", "b"));
        assert!(graph.are_mutually_recursive("b", "c"));
        assert!(graph.are_mutually_recursive("a", "c"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 3);
    }

    // Independent cycles stay separate, and a caller of both (`main`) is not
    // itself recursive.
    #[test]
    fn test_multiple_independent_cycles() {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                // Cycle 1: ping <-> pong
                make_word("ping", vec!["pong"]),
                make_word("pong", vec!["ping"]),
                // Cycle 2: even <-> odd
                make_word("even", vec!["odd"]),
                make_word("odd", vec!["even"]),
                // Non-recursive
                make_word("main", vec!["ping", "even"]),
            ],
        };

        let graph = CallGraph::build(&program);
        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(!graph.is_recursive("main"));

        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert!(graph.are_mutually_recursive("even", "odd"));
        assert!(!graph.are_mutually_recursive("ping", "even"));

        assert_eq!(graph.recursive_cycles().len(), 2);
    }

    // Calls to names not defined in the program are not recorded as edges.
    #[test]
    fn test_calls_to_unknown_words() {
        // Calls to builtins or external words should be ignored
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![make_word("foo", vec!["dup", "drop", "unknown_builtin"])],
        };

        let graph = CallGraph::build(&program);
        assert!(!graph.is_recursive("foo"));
        // Callees should only include known words
        assert!(graph.callees("foo").unwrap().is_empty());
    }

    // `should_use_musttail` is true for a call inside a cycle and false when
    // either endpoint is non-recursive.
    #[test]
    fn test_tail_call_info() {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                make_word("ping", vec!["pong"]),
                make_word("pong", vec!["ping"]),
                make_word("helper", vec![]),
            ],
        };

        let graph = CallGraph::build(&program);
        let info = TailCallInfo::from_call_graph(&graph);

        assert!(info.should_use_musttail("ping", "pong"));
        assert!(info.should_use_musttail("pong", "ping"));
        assert!(!info.should_use_musttail("helper", "ping"));
        assert!(!info.should_use_musttail("ping", "helper"));
    }

    // Builtin calls between user-word calls neither hide a cycle nor appear as edges.
    #[test]
    fn test_cycle_with_builtins_interspersed() {
        // Cycles should be detected even when builtins are called between user words
        // e.g., : foo dup drop bar ;  : bar swap foo ;
        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                make_word("foo", vec!["dup", "drop", "bar"]),
                make_word("bar", vec!["swap", "foo"]),
            ],
        };

        let graph = CallGraph::build(&program);
        // foo and bar should still form a cycle despite builtin calls
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));

        // Builtins should not appear in callees
        let foo_callees = graph.callees("foo").unwrap();
        assert!(foo_callees.contains("bar"));
        assert!(!foo_callees.contains("dup"));
        assert!(!foo_callees.contains("drop"));
    }

    // A call inside a quotation body counts as an edge from the enclosing word.
    #[test]
    fn test_cycle_through_quotation() {
        // Calls inside quotations should be detected
        // e.g., : foo [ bar ] call ;  : bar foo ;
        use crate::ast::Statement;

        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                WordDef {
                    name: "foo".to_string(),
                    effect: None,
                    body: vec![
                        Statement::Quotation {
                            id: 0,
                            body: vec![Statement::WordCall {
                                name: "bar".to_string(),
                                span: None,
                            }],
                            span: None,
                        },
                        Statement::WordCall {
                            name: "call".to_string(),
                            span: None,
                        },
                    ],
                    source: None,
                    allowed_lints: vec![],
                },
                make_word("bar", vec!["foo"]),
            ],
        };

        let graph = CallGraph::build(&program);
        // foo calls bar (inside quotation), bar calls foo
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));
    }

    // Calls inside the `else` branch of an `if` create edges (even/odd pattern).
    #[test]
    fn test_cycle_through_if_branch() {
        // Calls inside if branches should be detected
        use crate::ast::Statement;

        let program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![
                WordDef {
                    name: "even".to_string(),
                    effect: None,
                    body: vec![Statement::If {
                        then_branch: vec![],
                        else_branch: Some(vec![Statement::WordCall {
                            name: "odd".to_string(),
                            span: None,
                        }]),
                    }],
                    source: None,
                    allowed_lints: vec![],
                },
                WordDef {
                    name: "odd".to_string(),
                    effect: None,
                    body: vec![Statement::If {
                        then_branch: vec![],
                        else_branch: Some(vec![Statement::WordCall {
                            name: "even".to_string(),
                            span: None,
                        }]),
                    }],
                    source: None,
                    allowed_lints: vec![],
                },
            ],
        };

        let graph = CallGraph::build(&program);
        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(graph.are_mutually_recursive("even", "odd"));
    }
}