Skip to main content

tensorlogic_adapters/
rule_deps.rs

1//! Rule dependency graph for TensorLogic.
2//!
3//! Builds a directed graph where rules depend on predicates and predicates
4//! are defined by rules. Enables cycle detection, stratification, SCC
5//! computation, and transitive dependency analysis.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8
9use crate::SymbolTable;
10
11// ─────────────────────────────────────────────────────────────────────────────
12// DepNode
13// ─────────────────────────────────────────────────────────────────────────────
14
/// A vertex of the dependency graph: either a named rule or a named predicate.
///
/// The derived `Ord` follows variant declaration order, so every `Rule` sorts
/// before every `Predicate`; nodes of the same kind are ordered by name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum DepNode {
    /// A rule, identified by its name.
    Rule(String),
    /// A predicate, identified by its name.
    Predicate(String),
}

impl DepNode {
    /// Borrow the name carried by this node, whatever its kind.
    pub fn name(&self) -> &str {
        let label = match self {
            DepNode::Rule(label) => label,
            DepNode::Predicate(label) => label,
        };
        label.as_str()
    }

    /// `true` iff this node is a [`DepNode::Rule`].
    pub fn is_rule(&self) -> bool {
        match self {
            DepNode::Rule(_) => true,
            DepNode::Predicate(_) => false,
        }
    }

    /// `true` iff this node is a [`DepNode::Predicate`].
    pub fn is_predicate(&self) -> bool {
        match self {
            DepNode::Predicate(_) => true,
            DepNode::Rule(_) => false,
        }
    }
}

impl std::fmt::Display for DepNode {
    /// Renders `Rule(<name>)` for rules and `Pred(<name>)` for predicates.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DepNode::Predicate(n) => write!(f, "Pred({n})"),
            DepNode::Rule(n) => write!(f, "Rule({n})"),
        }
    }
}
51
52// ─────────────────────────────────────────────────────────────────────────────
53// DepEdge
54// ─────────────────────────────────────────────────────────────────────────────
55
/// The kind of a directed edge in the dependency graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DepEdge {
    /// The rule reads the predicate positively (a positive body literal).
    Positive,
    /// The rule reads the predicate under negation.
    Negative,
    /// The rule *defines* (writes to) the predicate, i.e. the predicate is its head.
    Defines,
}

impl std::fmt::Display for DepEdge {
    /// Renders a compact symbol: `+`, `−` (U+2212 MINUS SIGN) or `def`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            DepEdge::Positive => "+",
            DepEdge::Negative => "−",
            DepEdge::Defines => "def",
        };
        f.write_str(symbol)
    }
}
76
77// ─────────────────────────────────────────────────────────────────────────────
78// RuleDependencyGraph
79// ─────────────────────────────────────────────────────────────────────────────
80
81/// Directed graph capturing dependencies between rules and predicates.
82#[derive(Debug, Clone)]
83pub struct RuleDependencyGraph {
84    /// Adjacency list: node → list of (neighbour, edge_type).
85    edges: HashMap<DepNode, Vec<(DepNode, DepEdge)>>,
86    /// Full node set (includes nodes with no outgoing edges).
87    nodes: HashSet<DepNode>,
88}
89
90impl Default for RuleDependencyGraph {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96impl RuleDependencyGraph {
97    /// Create an empty graph.
98    pub fn new() -> Self {
99        RuleDependencyGraph {
100            edges: HashMap::new(),
101            nodes: HashSet::new(),
102        }
103    }
104
105    // ── Mutation ──────────────────────────────────────────────────────────────
106
107    /// Insert a node (idempotent).
108    pub fn add_node(&mut self, node: DepNode) {
109        self.nodes.insert(node.clone());
110        self.edges.entry(node).or_default();
111    }
112
113    /// Insert a directed edge from `from` to `to` with edge type `edge`.
114    /// Both endpoints are automatically added as nodes.
115    pub fn add_edge(&mut self, from: DepNode, to: DepNode, edge: DepEdge) {
116        self.add_node(from.clone());
117        self.add_node(to.clone());
118        self.edges.entry(from).or_default().push((to, edge));
119    }
120
121    // ── Construction ──────────────────────────────────────────────────────────
122
123    /// Build a dependency graph from a `SymbolTable`.
124    ///
125    /// Because `SymbolTable` stores predicates (not first-class rules with
126    /// heads/bodies), this method treats each predicate as both a *defining*
127    /// entity and a potential *dependency*.  For every predicate `p` it:
128    ///
129    /// 1. Adds `Predicate(p.name)` as a node.
130    /// 2. Creates a synthetic `Rule("<p>_rule")` that defines `p`.
131    /// 3. Adds a `Defines` edge from the rule node to the predicate node.
132    /// 4. For each argument domain `d` of `p`: adds `Predicate(d)` and a
133    ///    `Positive` edge from the rule to that domain predicate (modelling
134    ///    that evaluating `p` requires its domain to be populated).
135    pub fn from_symbol_table(table: &SymbolTable) -> Self {
136        let mut graph = RuleDependencyGraph::new();
137
138        for (pred_name, pred_info) in &table.predicates {
139            let pred_node = DepNode::Predicate(pred_name.clone());
140            let rule_node = DepNode::Rule(format!("{pred_name}_rule"));
141
142            graph.add_edge(rule_node.clone(), pred_node, DepEdge::Defines);
143
144            for domain_name in &pred_info.arg_domains {
145                let domain_node = DepNode::Predicate(domain_name.clone());
146                graph.add_edge(rule_node.clone(), domain_node, DepEdge::Positive);
147            }
148        }
149
150        graph
151    }
152
153    // ── Accessors ─────────────────────────────────────────────────────────────
154
155    /// All nodes in the graph.
156    pub fn nodes(&self) -> &HashSet<DepNode> {
157        &self.nodes
158    }
159
160    /// Nodes that `node` has outgoing edges to (successors / dependencies).
161    pub fn successors(&self, node: &DepNode) -> Vec<&DepNode> {
162        self.edges
163            .get(node)
164            .map(|v| v.iter().map(|(n, _)| n).collect())
165            .unwrap_or_default()
166    }
167
168    /// Nodes that have outgoing edges pointing to `node` (predecessors).
169    pub fn predecessors(&self, node: &DepNode) -> Vec<&DepNode> {
170        self.nodes
171            .iter()
172            .filter(|n| {
173                self.edges
174                    .get(n)
175                    .map(|v| v.iter().any(|(t, _)| t == node))
176                    .unwrap_or(false)
177            })
178            .collect()
179    }
180
181    /// Total number of nodes.
182    pub fn node_count(&self) -> usize {
183        self.nodes.len()
184    }
185
186    /// Total number of directed edges.
187    pub fn edge_count(&self) -> usize {
188        self.edges.values().map(|v| v.len()).sum()
189    }
190
191    // ── Cycle detection ───────────────────────────────────────────────────────
192
193    /// Returns `true` if the graph contains at least one directed cycle.
194    pub fn has_cycle(&self) -> bool {
195        let mut visited: HashSet<&DepNode> = HashSet::new();
196        let mut in_stack: HashSet<&DepNode> = HashSet::new();
197
198        for node in &self.nodes {
199            if !visited.contains(node) && self.dfs_has_cycle(node, &mut visited, &mut in_stack) {
200                return true;
201            }
202        }
203        false
204    }
205
206    fn dfs_has_cycle<'a>(
207        &'a self,
208        node: &'a DepNode,
209        visited: &mut HashSet<&'a DepNode>,
210        in_stack: &mut HashSet<&'a DepNode>,
211    ) -> bool {
212        visited.insert(node);
213        in_stack.insert(node);
214
215        if let Some(neighbours) = self.edges.get(node) {
216            for (next, _) in neighbours {
217                if !visited.contains(next) {
218                    if self.dfs_has_cycle(next, visited, in_stack) {
219                        return true;
220                    }
221                } else if in_stack.contains(next) {
222                    return true;
223                }
224            }
225        }
226
227        in_stack.remove(node);
228        false
229    }
230
231    /// Return the set of nodes that participate in *any* cycle.
232    pub fn find_cycle_nodes(&self) -> HashSet<&DepNode> {
233        // A node participates in a cycle iff it belongs to an SCC of size > 1
234        // OR has a self-loop.
235        let sccs = self.strongly_connected_components();
236        let mut result: HashSet<&DepNode> = HashSet::new();
237
238        for scc in &sccs {
239            if scc.len() > 1 {
240                for node in scc {
241                    if let Some(n) = self.nodes.get(node) {
242                        result.insert(n);
243                    }
244                }
245            } else if scc.len() == 1 {
246                // Check self-loop
247                let node = &scc[0];
248                if let Some(neighbours) = self.edges.get(node) {
249                    if neighbours.iter().any(|(t, _)| t == node) {
250                        if let Some(n) = self.nodes.get(node) {
251                            result.insert(n);
252                        }
253                    }
254                }
255            }
256        }
257
258        result
259    }
260
261    // ── Transitive dependencies ────────────────────────────────────────────────
262
263    /// Compute all nodes reachable from `node` via BFS (all edge types).
264    /// The starting node itself is *not* included in the result.
265    pub fn transitive_deps(&self, node: &DepNode) -> HashSet<DepNode> {
266        let mut visited: HashSet<DepNode> = HashSet::new();
267        let mut queue: VecDeque<DepNode> = VecDeque::new();
268
269        // Seed the queue with direct successors.
270        if let Some(neighbours) = self.edges.get(node) {
271            for (next, _) in neighbours {
272                if !visited.contains(next) {
273                    visited.insert(next.clone());
274                    queue.push_back(next.clone());
275                }
276            }
277        }
278
279        while let Some(current) = queue.pop_front() {
280            if let Some(neighbours) = self.edges.get(&current) {
281                for (next, _) in neighbours {
282                    if !visited.contains(next) {
283                        visited.insert(next.clone());
284                        queue.push_back(next.clone());
285                    }
286                }
287            }
288        }
289
290        visited
291    }
292
293    // ── Strongly Connected Components (Kosaraju's algorithm) ──────────────────
294
295    /// Compute all strongly connected components.
296    /// Each SCC is returned as a `Vec<DepNode>`; SCCs are in reverse topological
297    /// order (i.e. the first SCC has no outgoing edges to later SCCs).
298    pub fn strongly_connected_components(&self) -> Vec<Vec<DepNode>> {
299        // ── Pass 1: DFS on original graph, record finish order ────────────────
300        let mut visited: HashSet<&DepNode> = HashSet::new();
301        let mut finish_stack: Vec<&DepNode> = Vec::new();
302
303        for node in &self.nodes {
304            if !visited.contains(node) {
305                self.kosaraju_dfs_forward(node, &mut visited, &mut finish_stack);
306            }
307        }
308
309        // ── Build transposed graph ────────────────────────────────────────────
310        let transposed = self.transpose();
311
312        // ── Pass 2: DFS on transposed graph in reverse finish order ───────────
313        let mut visited2: HashSet<DepNode> = HashSet::new();
314        let mut sccs: Vec<Vec<DepNode>> = Vec::new();
315
316        for node in finish_stack.into_iter().rev() {
317            if !visited2.contains(node) {
318                let mut component: Vec<DepNode> = Vec::new();
319                Self::kosaraju_dfs_backward(node, &transposed, &mut visited2, &mut component);
320                sccs.push(component);
321            }
322        }
323
324        sccs
325    }
326
327    fn kosaraju_dfs_forward<'a>(
328        &'a self,
329        node: &'a DepNode,
330        visited: &mut HashSet<&'a DepNode>,
331        finish_stack: &mut Vec<&'a DepNode>,
332    ) {
333        visited.insert(node);
334        if let Some(neighbours) = self.edges.get(node) {
335            for (next, _) in neighbours {
336                if !visited.contains(next) {
337                    self.kosaraju_dfs_forward(next, visited, finish_stack);
338                }
339            }
340        }
341        finish_stack.push(node);
342    }
343
344    fn kosaraju_dfs_backward(
345        node: &DepNode,
346        transposed: &HashMap<DepNode, Vec<DepNode>>,
347        visited: &mut HashSet<DepNode>,
348        component: &mut Vec<DepNode>,
349    ) {
350        visited.insert(node.clone());
351        component.push(node.clone());
352
353        if let Some(neighbours) = transposed.get(node) {
354            for next in neighbours {
355                if !visited.contains(next) {
356                    Self::kosaraju_dfs_backward(next, transposed, visited, component);
357                }
358            }
359        }
360    }
361
362    /// Build the transpose (reverse) of this graph.
363    fn transpose(&self) -> HashMap<DepNode, Vec<DepNode>> {
364        let mut trans: HashMap<DepNode, Vec<DepNode>> = HashMap::new();
365
366        // Ensure every node appears (even without incoming edges).
367        for node in &self.nodes {
368            trans.entry(node.clone()).or_default();
369        }
370
371        for (from, neighbours) in &self.edges {
372            for (to, _) in neighbours {
373                trans.entry(to.clone()).or_default().push(from.clone());
374            }
375        }
376
377        trans
378    }
379
380    // ── Stratification ────────────────────────────────────────────────────────
381
382    /// Compute Datalog stratification layers.
383    ///
384    /// Returns `Ok(layers)` where layers are sorted by stratum index, or
385    /// `Err(StratificationError::NegativeCycle{..})` when the graph is
386    /// unstratifiable.
387    pub fn stratify(&self) -> Result<Vec<StratificationLayer>, StratificationError> {
388        // Assign every node an integer stratum starting at 0.
389        let mut stratum: HashMap<DepNode, usize> =
390            self.nodes.iter().map(|n| (n.clone(), 0_usize)).collect();
391
392        // Iterative fixed-point propagation.
393        let max_iters = self.nodes.len().saturating_add(1);
394        let mut changed = true;
395        let mut iter = 0_usize;
396
397        while changed && iter < max_iters {
398            changed = false;
399            iter = iter.saturating_add(1);
400
401            for (from, neighbours) in &self.edges {
402                let s_from = *stratum.get(from).unwrap_or(&0);
403                for (to, edge_kind) in neighbours {
404                    let min_stratum = match edge_kind {
405                        DepEdge::Positive | DepEdge::Defines => s_from,
406                        DepEdge::Negative => s_from.saturating_add(1),
407                    };
408                    let current = stratum.entry(to.clone()).or_insert(0);
409                    if min_stratum > *current {
410                        *current = min_stratum;
411                        changed = true;
412                    }
413                }
414            }
415        }
416
417        // Detect negative cycles: a negative edge (u→v) where stratum[u] >=
418        // stratum[v] after convergence indicates an unstratifiable graph.
419        let mut cycle_nodes: Vec<String> = Vec::new();
420        for (from, neighbours) in &self.edges {
421            let s_from = *stratum.get(from).unwrap_or(&0);
422            for (to, edge_kind) in neighbours {
423                if *edge_kind == DepEdge::Negative {
424                    let s_to = *stratum.get(to).unwrap_or(&0);
425                    if s_from >= s_to {
426                        cycle_nodes.push(from.name().to_owned());
427                        cycle_nodes.push(to.name().to_owned());
428                    }
429                }
430            }
431        }
432
433        if !cycle_nodes.is_empty() {
434            cycle_nodes.sort();
435            cycle_nodes.dedup();
436            return Err(StratificationError::NegativeCycle {
437                participating_nodes: cycle_nodes,
438            });
439        }
440
441        // Group nodes by stratum.
442        let mut layers_map: HashMap<usize, Vec<DepNode>> = HashMap::new();
443        for (node, s) in &stratum {
444            layers_map.entry(*s).or_default().push(node.clone());
445        }
446
447        // Determine which strata have at least one negative incoming edge.
448        let mut negative_strata: HashSet<usize> = HashSet::new();
449        for (from, neighbours) in &self.edges {
450            let s_from = *stratum.get(from).unwrap_or(&0);
451            for (to, edge_kind) in neighbours {
452                if *edge_kind == DepEdge::Negative {
453                    let s_to = *stratum.get(to).unwrap_or(&0);
454                    // The target stratum is strictly higher due to the +1 rule.
455                    if s_to > s_from {
456                        negative_strata.insert(s_to);
457                    }
458                }
459            }
460        }
461
462        let mut sorted_strata: Vec<usize> = layers_map.keys().copied().collect();
463        sorted_strata.sort_unstable();
464
465        let layers: Vec<StratificationLayer> = sorted_strata
466            .into_iter()
467            .map(|s| {
468                let mut nodes = layers_map.remove(&s).unwrap_or_default();
469                nodes.sort();
470                StratificationLayer {
471                    stratum: s,
472                    nodes,
473                    has_negation: negative_strata.contains(&s),
474                }
475            })
476            .collect();
477
478        Ok(layers)
479    }
480
481    // ── Rendering ─────────────────────────────────────────────────────────────
482
483    /// Render as a human-readable ASCII adjacency list (for debugging).
484    pub fn to_ascii(&self) -> String {
485        let mut buf = String::new();
486        let mut sorted_nodes: Vec<&DepNode> = self.nodes.iter().collect();
487        sorted_nodes.sort();
488
489        for node in sorted_nodes {
490            buf.push_str(&format!("{node}"));
491            let mut succs: Vec<String> = self
492                .edges
493                .get(node)
494                .map(|v| v.iter().map(|(n, e)| format!("  →{n}[{e}]")).collect())
495                .unwrap_or_default();
496            succs.sort();
497
498            if succs.is_empty() {
499                buf.push_str(" (leaf)\n");
500            } else {
501                buf.push('\n');
502                for s in succs {
503                    buf.push_str(&s);
504                    buf.push('\n');
505                }
506            }
507        }
508
509        buf
510    }
511
512    /// Render as Graphviz DOT format.
513    pub fn to_dot(&self) -> String {
514        let mut buf = String::from("digraph rule_deps {\n    rankdir=LR;\n");
515
516        // Node declarations with shape hints.
517        let mut sorted_nodes: Vec<&DepNode> = self.nodes.iter().collect();
518        sorted_nodes.sort();
519
520        for node in &sorted_nodes {
521            let (shape, label) = match node {
522                DepNode::Rule(n) => ("box", format!("Rule\\n{n}")),
523                DepNode::Predicate(n) => ("ellipse", format!("Pred\\n{n}")),
524            };
525            let id = dot_id(node);
526            buf.push_str(&format!("    {id} [label=\"{label}\" shape={shape}];\n"));
527        }
528
529        // Edge declarations.
530        for from in &sorted_nodes {
531            if let Some(neighbours) = self.edges.get(from) {
532                let mut sorted_neighbours: Vec<&(DepNode, DepEdge)> = neighbours.iter().collect();
533                sorted_neighbours.sort_by_key(|(n, _)| n);
534
535                for (to, edge_kind) in sorted_neighbours {
536                    let from_id = dot_id(from);
537                    let to_id = dot_id(to);
538                    let (style, label) = match edge_kind {
539                        DepEdge::Positive => ("solid", "pos"),
540                        DepEdge::Negative => ("dashed", "neg"),
541                        DepEdge::Defines => ("bold", "def"),
542                    };
543                    buf.push_str(&format!(
544                        "    {from_id} -> {to_id} [label=\"{label}\" style={style}];\n"
545                    ));
546                }
547            }
548        }
549
550        buf.push('}');
551        buf
552    }
553}
554
555/// Sanitise a node name to a valid DOT identifier.
556fn dot_id(node: &DepNode) -> String {
557    let prefix = if node.is_rule() { "r_" } else { "p_" };
558    let name: String = node
559        .name()
560        .chars()
561        .map(|c| {
562            if c.is_alphanumeric() || c == '_' {
563                c
564            } else {
565                '_'
566            }
567        })
568        .collect();
569    format!("{prefix}{name}")
570}
571
572// ─────────────────────────────────────────────────────────────────────────────
573// StratificationLayer
574// ─────────────────────────────────────────────────────────────────────────────
575
576/// A set of nodes that can be evaluated at the same stratum.
577#[derive(Debug, Clone)]
578pub struct StratificationLayer {
579    /// Zero-based stratum index (lower = evaluated first).
580    pub stratum: usize,
581    /// All nodes at this stratum (sorted for determinism).
582    pub nodes: Vec<DepNode>,
583    /// `true` when at least one incoming edge to this stratum is `Negative`.
584    pub has_negation: bool,
585}
586
587// ─────────────────────────────────────────────────────────────────────────────
588// StratificationError
589// ─────────────────────────────────────────────────────────────────────────────
590
/// Failure modes of [`RuleDependencyGraph::stratify`].
#[derive(Debug, Clone)]
pub enum StratificationError {
    /// A cycle runs through at least one negative edge, so no stratum
    /// assignment exists.
    NegativeCycle {
        /// Names of the nodes implicated in the negative cycle.
        participating_nodes: Vec<String>,
    },
    /// Any other stratification failure, described by the message.
    UnstratifiableGraph(String),
}

impl std::fmt::Display for StratificationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StratificationError::UnstratifiableGraph(msg) => {
                write!(f, "Unstratifiable graph: {msg}")
            }
            StratificationError::NegativeCycle {
                participating_nodes,
            } => {
                let joined = participating_nodes.join(", ");
                write!(f, "Negative cycle detected involving nodes: [{}]", joined)
            }
        }
    }
}

impl std::error::Error for StratificationError {}
623
624// ─────────────────────────────────────────────────────────────────────────────
625// DepGraphStats
626// ─────────────────────────────────────────────────────────────────────────────
627
628/// Summary statistics for a `RuleDependencyGraph`.
629#[derive(Debug, Clone)]
630pub struct DepGraphStats {
631    /// Total node count.
632    pub num_nodes: usize,
633    /// Total edge count.
634    pub num_edges: usize,
635    /// Number of rule nodes.
636    pub num_rules: usize,
637    /// Number of predicate nodes.
638    pub num_predicates: usize,
639    /// Whether the graph contains any directed cycle.
640    pub has_cycles: bool,
641    /// Number of strongly connected components.
642    pub num_sccs: usize,
643    /// Size of the largest SCC.
644    pub max_scc_size: usize,
645    /// Number of strata (`None` when the graph is not stratifiable).
646    pub num_strata: Option<usize>,
647    /// Length of the longest chain of dependencies (BFS diameter from any node).
648    pub longest_dependency_chain: usize,
649}
650
651impl DepGraphStats {
652    /// Compute statistics for the given graph.
653    pub fn compute(graph: &RuleDependencyGraph) -> Self {
654        let num_nodes = graph.node_count();
655        let num_edges = graph.edge_count();
656        let num_rules = graph.nodes.iter().filter(|n| n.is_rule()).count();
657        let num_predicates = graph.nodes.iter().filter(|n| n.is_predicate()).count();
658        let has_cycles = graph.has_cycle();
659
660        let sccs = graph.strongly_connected_components();
661        let num_sccs = sccs.len();
662        let max_scc_size = sccs.iter().map(|s| s.len()).max().unwrap_or(0);
663
664        let num_strata = match graph.stratify() {
665            Ok(layers) => Some(layers.len()),
666            Err(_) => None,
667        };
668
669        let longest_dependency_chain = compute_longest_chain(graph);
670
671        DepGraphStats {
672            num_nodes,
673            num_edges,
674            num_rules,
675            num_predicates,
676            has_cycles,
677            num_sccs,
678            max_scc_size,
679            num_strata,
680            longest_dependency_chain,
681        }
682    }
683}
684
685/// BFS-based longest chain length across all starting nodes.
686fn compute_longest_chain(graph: &RuleDependencyGraph) -> usize {
687    let mut max_len = 0_usize;
688
689    for start in &graph.nodes {
690        let mut dist: HashMap<&DepNode, usize> = HashMap::new();
691        let mut queue: VecDeque<&DepNode> = VecDeque::new();
692        dist.insert(start, 0);
693        queue.push_back(start);
694
695        while let Some(cur) = queue.pop_front() {
696            let cur_dist = *dist.get(cur).unwrap_or(&0);
697            if let Some(neighbours) = graph.edges.get(cur) {
698                for (next, _) in neighbours {
699                    if !dist.contains_key(next) {
700                        dist.insert(next, cur_dist + 1);
701                        queue.push_back(next);
702                        if cur_dist + 1 > max_len {
703                            max_len = cur_dist + 1;
704                        }
705                    }
706                }
707            }
708        }
709    }
710
711    max_len
712}
713
714// ─────────────────────────────────────────────────────────────────────────────
715// Tests
716// ─────────────────────────────────────────────────────────────────────────────
717
718#[cfg(test)]
719mod tests {
720    use super::*;
721
722    // ── helpers ────────────────────────────────────────────────────────────────
723
724    fn rule(n: &str) -> DepNode {
725        DepNode::Rule(n.to_owned())
726    }
727
728    fn pred(n: &str) -> DepNode {
729        DepNode::Predicate(n.to_owned())
730    }
731
732    // ── DepNode ────────────────────────────────────────────────────────────────
733
734    #[test]
735    fn test_dep_node_name() {
736        assert_eq!(rule("foo").name(), "foo");
737        assert_eq!(pred("bar").name(), "bar");
738    }
739
740    #[test]
741    fn test_dep_node_is_rule_predicate() {
742        let r = rule("r1");
743        let p = pred("p1");
744        assert!(r.is_rule());
745        assert!(!r.is_predicate());
746        assert!(p.is_predicate());
747        assert!(!p.is_rule());
748    }
749
750    // ── Graph construction ────────────────────────────────────────────────────
751
752    #[test]
753    fn test_add_node_and_edge() {
754        let mut g = RuleDependencyGraph::new();
755        g.add_node(rule("r1"));
756        g.add_node(pred("p1"));
757        assert_eq!(g.node_count(), 2);
758        assert_eq!(g.edge_count(), 0);
759
760        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
761        assert_eq!(g.edge_count(), 1);
762        // add_edge should not duplicate nodes
763        assert_eq!(g.node_count(), 2);
764    }
765
766    #[test]
767    fn test_successors() {
768        let mut g = RuleDependencyGraph::new();
769        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
770        g.add_edge(rule("r1"), pred("p2"), DepEdge::Positive);
771
772        let mut succs: Vec<&DepNode> = g.successors(&rule("r1"));
773        succs.sort();
774        assert_eq!(succs.len(), 2);
775        assert!(succs.contains(&&pred("p1")));
776        assert!(succs.contains(&&pred("p2")));
777    }
778
779    #[test]
780    fn test_predecessors() {
781        let mut g = RuleDependencyGraph::new();
782        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
783        g.add_edge(rule("r2"), pred("p1"), DepEdge::Positive);
784
785        let preds = g.predecessors(&pred("p1"));
786        assert_eq!(preds.len(), 2);
787        assert!(preds.contains(&&rule("r1")));
788        assert!(preds.contains(&&rule("r2")));
789    }
790
791    // ── Cycle detection ───────────────────────────────────────────────────────
792
793    #[test]
794    fn test_has_cycle_false() {
795        let mut g = RuleDependencyGraph::new();
796        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
797        g.add_edge(pred("p1"), pred("p2"), DepEdge::Positive);
798        assert!(!g.has_cycle());
799    }
800
801    #[test]
802    fn test_has_cycle_true() {
803        let mut g = RuleDependencyGraph::new();
804        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
805        g.add_edge(pred("b"), pred("a"), DepEdge::Positive);
806        assert!(g.has_cycle());
807    }
808
809    #[test]
810    fn test_find_cycle_nodes() {
811        let mut g = RuleDependencyGraph::new();
812        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
813        g.add_edge(pred("b"), pred("a"), DepEdge::Positive);
814        // pred("c") is outside the cycle
815        g.add_edge(pred("c"), pred("a"), DepEdge::Positive);
816
817        let cycle_nodes = g.find_cycle_nodes();
818        assert!(cycle_nodes.contains(&pred("a")));
819        assert!(cycle_nodes.contains(&pred("b")));
820        assert!(!cycle_nodes.contains(&pred("c")));
821    }
822
823    // ── Transitive deps ───────────────────────────────────────────────────────
824
825    #[test]
826    fn test_transitive_deps_simple() {
827        let mut g = RuleDependencyGraph::new();
828        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
829        g.add_edge(pred("b"), pred("c"), DepEdge::Positive);
830
831        let deps = g.transitive_deps(&pred("a"));
832        assert!(deps.contains(&pred("b")));
833        assert!(deps.contains(&pred("c")));
834        assert!(!deps.contains(&pred("a")));
835    }
836
837    #[test]
838    fn test_transitive_deps_empty() {
839        let mut g = RuleDependencyGraph::new();
840        g.add_node(pred("leaf"));
841
842        let deps = g.transitive_deps(&pred("leaf"));
843        assert!(deps.is_empty());
844    }
845
846    // ── SCCs ──────────────────────────────────────────────────────────────────
847
848    #[test]
849    fn test_scc_single_node() {
850        let mut g = RuleDependencyGraph::new();
851        g.add_node(pred("p1"));
852
853        let sccs = g.strongly_connected_components();
854        assert_eq!(sccs.len(), 1);
855        assert_eq!(sccs[0].len(), 1);
856    }
857
858    #[test]
859    fn test_scc_cycle() {
860        let mut g = RuleDependencyGraph::new();
861        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
862        g.add_edge(pred("b"), pred("a"), DepEdge::Positive);
863
864        let sccs = g.strongly_connected_components();
865        // Should find exactly one SCC of size 2.
866        let big: Vec<_> = sccs.iter().filter(|s| s.len() == 2).collect();
867        assert_eq!(big.len(), 1);
868        let scc = &big[0];
869        assert!(scc.contains(&pred("a")));
870        assert!(scc.contains(&pred("b")));
871    }
872
873    #[test]
874    fn test_scc_dag() {
875        let mut g = RuleDependencyGraph::new();
876        // Pure DAG: a→b→c, no back-edges.
877        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
878        g.add_edge(pred("b"), pred("c"), DepEdge::Positive);
879
880        let sccs = g.strongly_connected_components();
881        // Every node is its own SCC.
882        assert_eq!(sccs.len(), 3);
883        assert!(sccs.iter().all(|s| s.len() == 1));
884    }
885
886    // ── Stratification ────────────────────────────────────────────────────────
887
888    #[test]
889    fn test_stratify_simple_dag() {
890        let mut g = RuleDependencyGraph::new();
891        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
892        g.add_edge(pred("b"), pred("c"), DepEdge::Positive);
893
894        let layers = g.stratify().expect("should stratify");
895        // a, b, c must each be at stratum 0 because all edges are Positive
896        // (stratum[v] = max(stratum[v], stratum[u]) — same stratum is fine).
897        // The exact assignment: all at 0.
898        let get_stratum = |name: &str| -> usize {
899            layers
900                .iter()
901                .find(|l| l.nodes.contains(&pred(name)))
902                .map(|l| l.stratum)
903                .expect("node present")
904        };
905        // With only Positive edges the fixed-point keeps all at 0.
906        assert_eq!(get_stratum("a"), 0);
907        assert_eq!(get_stratum("b"), 0);
908        assert_eq!(get_stratum("c"), 0);
909    }
910
911    #[test]
912    fn test_stratify_with_negation() {
913        let mut g = RuleDependencyGraph::new();
914        // a -neg→ b: b must be at a higher stratum than a.
915        g.add_edge(pred("a"), pred("b"), DepEdge::Negative);
916
917        let layers = g.stratify().expect("should stratify");
918        let stratum_a = layers
919            .iter()
920            .find(|l| l.nodes.contains(&pred("a")))
921            .map(|l| l.stratum)
922            .expect("a present");
923        let stratum_b = layers
924            .iter()
925            .find(|l| l.nodes.contains(&pred("b")))
926            .map(|l| l.stratum)
927            .expect("b present");
928        assert!(stratum_b > stratum_a);
929    }
930
931    #[test]
932    fn test_stratify_negative_cycle_error() {
933        let mut g = RuleDependencyGraph::new();
934        // A -neg→ B -neg→ A  ⇒ unstratifiable.
935        g.add_edge(pred("a"), pred("b"), DepEdge::Negative);
936        g.add_edge(pred("b"), pred("a"), DepEdge::Negative);
937
938        let result = g.stratify();
939        assert!(
940            matches!(result, Err(StratificationError::NegativeCycle { .. })),
941            "expected NegativeCycle, got: {result:?}"
942        );
943    }
944
945    // ── Stats ──────────────────────────────────────────────────────────────────
946
947    #[test]
948    fn test_dep_graph_stats_basic() {
949        let mut g = RuleDependencyGraph::new();
950        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
951        g.add_edge(rule("r1"), pred("p2"), DepEdge::Positive);
952
953        let stats = DepGraphStats::compute(&g);
954        assert_eq!(stats.num_nodes, 3);
955        assert_eq!(stats.num_edges, 2);
956        assert_eq!(stats.num_rules, 1);
957        assert_eq!(stats.num_predicates, 2);
958    }
959
960    #[test]
961    fn test_dep_graph_stats_has_cycles() {
962        let mut g = RuleDependencyGraph::new();
963        g.add_edge(pred("a"), pred("b"), DepEdge::Positive);
964        g.add_edge(pred("b"), pred("a"), DepEdge::Positive);
965
966        let stats = DepGraphStats::compute(&g);
967        assert!(stats.has_cycles);
968    }
969
970    // ── Rendering ─────────────────────────────────────────────────────────────
971
972    #[test]
973    fn test_to_ascii_nonempty() {
974        let mut g = RuleDependencyGraph::new();
975        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
976
977        let ascii = g.to_ascii();
978        assert!(!ascii.is_empty());
979    }
980
981    #[test]
982    fn test_to_dot_contains_digraph() {
983        let mut g = RuleDependencyGraph::new();
984        g.add_edge(rule("r1"), pred("p1"), DepEdge::Defines);
985
986        let dot = g.to_dot();
987        assert!(dot.contains("digraph"));
988    }
989}