Skip to main content

tensorlogic_infer/
critical_path.rs

1//! Critical path analysis for inference computation graphs.
2//!
3//! Computes the longest dependency chain in a DAG (the critical path) using:
4//! 1. Kahn's algorithm for topological sort.
5//! 2. DP over the topological order: `dist[v] = max_over_predecessors(dist[u] + cost(v))`.
6//!
//! The result includes the path as a sequence of [`NodeId`]s, the total
//! accumulated latency in nanoseconds, and the single bottleneck node
//! (the node on the critical path with the highest individual latency).
10//!
11//! Nodes whose `latency_ns` is `None` are treated as 1 ns and a
12//! [`MissingCostWarning`] is emitted for each one.
13//!
14//! Cycle detection: if Kahn's algorithm cannot drain all nodes a
15//! [`CriticalPathError::CycleDetected`] is returned instead of panicking.
16//!
17//! # Example
18//!
19//! ```rust
20//! use tensorlogic_infer::critical_path::{
21//!     InferenceGraph, NodeLatency, critical_path,
22//! };
23//!
24//! let mut g = InferenceGraph::default();
25//! let a = g.add_node(NodeLatency { latency_ns: Some(10) });
26//! let b = g.add_node(NodeLatency { latency_ns: Some(20) });
27//! let c = g.add_node(NodeLatency { latency_ns: Some(5) });
28//! g.add_edge(a, b).unwrap();
29//! g.add_edge(b, c).unwrap();
30//!
31//! let result = critical_path(&g).unwrap();
32//! assert_eq!(result.report.nodes, vec![a, b, c]);
33//! assert_eq!(result.report.total_latency_ns, 35);
34//! ```
35
36use std::collections::VecDeque;
37use thiserror::Error;
38
39// ─────────────────────────────────────────────────────────────────────────────
40// Public types
41// ─────────────────────────────────────────────────────────────────────────────
42
/// Unique identifier for a node in an [`InferenceGraph`].
///
/// Ids are plain indices into [`InferenceGraph::nodes`]: the first node added
/// gets id `0`, the next one `1`, and so on.
pub type NodeId = usize;
45
/// Per-node latency annotation.
///
/// When `latency_ns` is `None` the analysis falls back to 1 ns and emits a
/// [`MissingCostWarning`].
///
/// The [`Default`] value is an unannotated node (`latency_ns: None`).
/// Annotations compare equal when their `latency_ns` values are equal, so they
/// can be used directly in `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodeLatency {
    /// Estimated execution latency in nanoseconds for this node, or `None`
    /// when no estimate is available.
    pub latency_ns: Option<u64>,
}

impl NodeLatency {
    /// Create an annotation with a known latency of `latency_ns` nanoseconds.
    ///
    /// Equivalent to `NodeLatency { latency_ns: Some(latency_ns) }`; usable in
    /// `const` contexts.
    #[must_use]
    pub const fn new(latency_ns: u64) -> Self {
        Self {
            latency_ns: Some(latency_ns),
        }
    }
}
64
65/// A lightweight Directed Acyclic Graph (DAG) of inference nodes.
66///
67/// Nodes are added in order and receive consecutive [`NodeId`]s starting from
68/// zero.  Edges encode data-flow dependencies: edge `(from, to)` means "node
69/// `from` must execute before node `to`".
70#[derive(Debug, Clone, Default)]
71pub struct InferenceGraph {
72    /// Per-node latency annotations; index == [`NodeId`].
73    pub nodes: Vec<NodeLatency>,
74    /// Directed edges `(from, to)` — i.e. `from` → `to`.
75    pub edges: Vec<(NodeId, NodeId)>,
76}
77
78impl InferenceGraph {
79    /// Create an empty graph.
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Add a node with the given latency annotation and return its [`NodeId`].
85    pub fn add_node(&mut self, latency: NodeLatency) -> NodeId {
86        let id = self.nodes.len();
87        self.nodes.push(latency);
88        id
89    }
90
91    /// Add a directed edge `from → to`.
92    ///
93    /// Returns [`CriticalPathError::InvalidNode`] if either node index is out
94    /// of range.
95    pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<(), CriticalPathError> {
96        let n = self.nodes.len();
97        if from >= n {
98            return Err(CriticalPathError::InvalidNode(from));
99        }
100        if to >= n {
101            return Err(CriticalPathError::InvalidNode(to));
102        }
103        self.edges.push((from, to));
104        Ok(())
105    }
106
107    /// Number of nodes in the graph.
108    pub fn num_nodes(&self) -> usize {
109        self.nodes.len()
110    }
111
112    /// Number of edges in the graph.
113    pub fn num_edges(&self) -> usize {
114        self.edges.len()
115    }
116}
117
118// ─────────────────────────────────────────────────────────────────────────────
119// Result types
120// ─────────────────────────────────────────────────────────────────────────────
121
/// The critical-path analysis report produced by [`critical_path`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CriticalPathReport {
    /// Ordered sequence of [`NodeId`]s on the critical path, from source to
    /// sink (inclusive).  Empty when the graph has no nodes.
    pub nodes: Vec<NodeId>,
    /// Sum of node latencies along the critical path, in nanoseconds.
    /// Nodes without a latency annotation count as 1 ns, and the sum
    /// saturates at `u64::MAX` instead of overflowing.
    pub total_latency_ns: u64,
    /// The node with the highest individual latency on the critical path.
    /// Zero when `nodes` is empty; since `0` is also a valid node id, check
    /// `nodes.is_empty()` before interpreting this field.
    pub bottleneck: NodeId,
}
134
/// Warning emitted when a node has no latency annotation.
///
/// [`critical_path`] records one warning for every node in the graph whose
/// `latency_ns` is `None` (whether or not the node ends up on the critical
/// path) and uses a cost of 1 ns for that node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingCostWarning {
    /// The node whose `latency_ns` was `None`.
    pub node_id: NodeId,
}
141
142/// Combined result: analysis report plus any latency-annotation warnings.
143#[derive(Debug, Clone)]
144pub struct CriticalPathResult {
145    pub report: CriticalPathReport,
146    pub warnings: Vec<MissingCostWarning>,
147}
148
/// Errors returned by [`critical_path`] and [`InferenceGraph::add_edge`].
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum CriticalPathError {
    /// The graph contains a directed cycle; critical-path analysis requires a
    /// DAG.  The enclosed string is a `", "`-separated list of the ids of the
    /// nodes that the topological sort could not process.
    #[error("Cycle detected; these nodes were not reachable via topological sort: {0}")]
    CycleDetected(String),

    /// An edge references a node index that does not exist in the graph.
    /// The enclosed value is the offending node id.
    #[error("Edge references out-of-range node id {0}")]
    InvalidNode(NodeId),
}
161
162// ─────────────────────────────────────────────────────────────────────────────
163// Core algorithm
164// ─────────────────────────────────────────────────────────────────────────────
165
166/// Compute the critical path of `graph`.
167///
168/// Returns `Ok(CriticalPathResult)` for valid DAGs, or
169/// `Err(CriticalPathError::CycleDetected)` if the graph is cyclic.
170///
171/// # Algorithm
172///
173/// 1. Build forward adjacency list and reverse adjacency list (predecessor map)
174///    together with in-degree counts.
175/// 2. Run Kahn's BFS topological sort.  If any nodes are left unprocessed,
176///    a cycle exists.
177/// 3. DP over topo order: `dist[v] = cost(v) + max(dist[u] for u in pred(v))`.
178///    Track the predecessor that achieved the maximum for path reconstruction.
179/// 4. The node with the maximum `dist` value is the end of the critical path.
180/// 5. Walk predecessors back to reconstruct the path, then reverse it.
181pub fn critical_path(graph: &InferenceGraph) -> Result<CriticalPathResult, CriticalPathError> {
182    let n = graph.num_nodes();
183
184    // Empty graph — trivial result.
185    if n == 0 {
186        return Ok(CriticalPathResult {
187            report: CriticalPathReport {
188                nodes: vec![],
189                total_latency_ns: 0,
190                bottleneck: 0,
191            },
192            warnings: vec![],
193        });
194    }
195
196    // ── 1. Build adjacency structures ──────────────────────────────────────
197    //
198    // `succ[u]` = list of nodes that depend on u (successors).
199    // `pred[v]` = list of nodes that v depends on (predecessors).
200    // `in_degree[v]` = number of predecessors.
201
202    let mut succ: Vec<Vec<NodeId>> = vec![vec![]; n];
203    let mut pred: Vec<Vec<NodeId>> = vec![vec![]; n];
204    let mut in_degree: Vec<usize> = vec![0; n];
205
206    for &(from, to) in &graph.edges {
207        // Edge validity was checked at add_edge time, but edges could have
208        // been added to the raw field directly; guard anyway.
209        if from >= n || to >= n {
210            return Err(CriticalPathError::InvalidNode(if from >= n {
211                from
212            } else {
213                to
214            }));
215        }
216        succ[from].push(to);
217        pred[to].push(from);
218        in_degree[to] += 1;
219    }
220
221    // ── 2. Collect per-node costs; emit warnings for missing annotations ───
222
223    let mut warnings: Vec<MissingCostWarning> = vec![];
224    let costs: Vec<u64> = graph
225        .nodes
226        .iter()
227        .enumerate()
228        .map(|(id, nl)| {
229            nl.latency_ns.unwrap_or_else(|| {
230                warnings.push(MissingCostWarning { node_id: id });
231                1
232            })
233        })
234        .collect();
235
236    // ── 3. Kahn's BFS topological sort ─────────────────────────────────────
237
238    let mut queue: VecDeque<NodeId> = VecDeque::new();
239    for v in 0..n {
240        if in_degree[v] == 0 {
241            queue.push_back(v);
242        }
243    }
244
245    let mut topo_order: Vec<NodeId> = Vec::with_capacity(n);
246    // Work on a mutable copy of in-degrees so we can decrement during BFS.
247    let mut remaining_in: Vec<usize> = in_degree.clone();
248
249    while let Some(u) = queue.pop_front() {
250        topo_order.push(u);
251        for &v in &succ[u] {
252            remaining_in[v] -= 1;
253            if remaining_in[v] == 0 {
254                queue.push_back(v);
255            }
256        }
257    }
258
259    if topo_order.len() != n {
260        // Some nodes were not processed — there is a cycle.
261        let cyclic: Vec<String> = (0..n)
262            .filter(|&v| !topo_order.contains(&v))
263            .map(|v| v.to_string())
264            .collect();
265        return Err(CriticalPathError::CycleDetected(cyclic.join(", ")));
266    }
267
268    // ── 4. DP: longest path in topo order ──────────────────────────────────
269    //
270    // `dist[v]` = maximum accumulated latency of any path ending at `v`
271    //             (including v's own cost).
272    // `best_pred[v]` = the predecessor that achieved `dist[v]`, or `None` for
273    //                  source nodes.
274
275    let mut dist: Vec<u64> = vec![0; n];
276    let mut best_pred: Vec<Option<NodeId>> = vec![None; n];
277
278    for &v in &topo_order {
279        // Start with just this node's own cost.
280        dist[v] = costs[v];
281        best_pred[v] = None;
282
283        // Extend the longest predecessor path.
284        for &u in &pred[v] {
285            let candidate = dist[u].saturating_add(costs[v]);
286            if candidate > dist[v] {
287                dist[v] = candidate;
288                best_pred[v] = Some(u);
289            }
290        }
291    }
292
293    // ── 5. Find the sink with the maximum distance ─────────────────────────
294
295    let (end_node, &max_dist) = dist
296        .iter()
297        .enumerate()
298        .max_by_key(|&(_, d)| d)
299        .unwrap_or((0, &0)); // Safety: n > 0 so the iterator is non-empty.
300
301    // ── 6. Reconstruct the path by walking back through best_pred ──────────
302
303    let mut path: Vec<NodeId> = vec![];
304    let mut current = end_node;
305    loop {
306        path.push(current);
307        match best_pred[current] {
308            Some(prev) => current = prev,
309            None => break,
310        }
311    }
312    path.reverse();
313
314    // ── 7. Identify bottleneck: node on path with highest individual cost ──
315
316    let bottleneck = path
317        .iter()
318        .copied()
319        .max_by_key(|&v| costs[v])
320        .unwrap_or(end_node);
321
322    Ok(CriticalPathResult {
323        report: CriticalPathReport {
324            nodes: path,
325            total_latency_ns: max_dist,
326            bottleneck,
327        },
328        warnings,
329    })
330}
331
332// ─────────────────────────────────────────────────────────────────────────────
333// Tests
334// ─────────────────────────────────────────────────────────────────────────────
335
/// Unit tests for [`critical_path`] and [`InferenceGraph`].
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: build a graph from a node-latency list and an edge list.
    // Node ids are the indices into `latencies`; panics on an invalid edge.
    fn build_graph(latencies: &[Option<u64>], edges: &[(usize, usize)]) -> InferenceGraph {
        let mut g = InferenceGraph::new();
        for &lat in latencies {
            g.add_node(NodeLatency { latency_ns: lat });
        }
        for &(from, to) in edges {
            g.add_edge(from, to).expect("valid edge");
        }
        g
    }

    // ── Test 1: linear chain A→B→C ────────────────────────────────────────

    #[test]
    fn test_linear_chain() {
        // A(10) → B(20) → C(5)  — only one path, total = 35
        let g = build_graph(&[Some(10), Some(20), Some(5)], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.nodes, vec![0, 1, 2]);
        assert_eq!(res.report.total_latency_ns, 35);
        assert_eq!(res.report.bottleneck, 1); // node 1 has cost 20
        assert!(res.warnings.is_empty());
    }

    // ── Test 2: diamond — longer branch wins ─────────────────────────────

    #[test]
    fn test_diamond_longer_branch_wins() {
        // A(1) → B(100) → D(1)
        //      → C(1)   → D
        // Critical path: A→B→D, total = 102
        let g = build_graph(
            &[Some(1), Some(100), Some(1), Some(1)],
            &[(0, 1), (0, 2), (1, 3), (2, 3)],
        );
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.nodes, vec![0, 1, 3]);
        assert_eq!(res.report.total_latency_ns, 102);
        assert_eq!(res.report.bottleneck, 1); // node 1 has cost 100
        assert!(res.warnings.is_empty());
    }

    // ── Test 3: single node ───────────────────────────────────────────────

    #[test]
    fn test_single_node() {
        let g = build_graph(&[Some(42)], &[]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.nodes, vec![0]);
        assert_eq!(res.report.total_latency_ns, 42);
        assert_eq!(res.report.bottleneck, 0);
        assert!(res.warnings.is_empty());
    }

    // ── Test 4: missing latency annotations emit warnings ─────────────────

    #[test]
    fn test_missing_latency_warning() {
        // A(None) → B(None) → C(None)
        // Each falls back to 1 ns → total = 3, warnings for all three.
        let g = build_graph(&[None, None, None], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 3);
        assert_eq!(res.warnings.len(), 3);
        let warned_ids: Vec<NodeId> = res.warnings.iter().map(|w| w.node_id).collect();
        assert!(warned_ids.contains(&0));
        assert!(warned_ids.contains(&1));
        assert!(warned_ids.contains(&2));
    }

    // ── Test 5: empty graph ───────────────────────────────────────────────

    #[test]
    fn test_empty_graph() {
        let g = InferenceGraph::new();
        let res = critical_path(&g).expect("no cycle");

        assert!(res.report.nodes.is_empty());
        assert_eq!(res.report.total_latency_ns, 0);
        assert_eq!(res.report.bottleneck, 0);
        assert!(res.warnings.is_empty());
    }

    // ── Test 6: cycle detection returns an error ──────────────────────────

    #[test]
    fn test_cycle_detected() {
        // A → B → C → A  (cycle)
        let g = build_graph(&[Some(1), Some(1), Some(1)], &[(0, 1), (1, 2), (2, 0)]);
        let err = critical_path(&g).expect_err("should detect cycle");
        // `matches!` only yields a bool; it has to be asserted to check anything.
        assert!(matches!(err, CriticalPathError::CycleDetected(_)));
    }

    // ── Test 7: parallel branches without shared sink ─────────────────────

    #[test]
    fn test_parallel_branches() {
        // Two independent chains: A(5)→B(10) and C(1)→D(3)
        // Longest path ends at B with dist 15.
        let g = build_graph(&[Some(5), Some(10), Some(1), Some(3)], &[(0, 1), (2, 3)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 15);
        assert_eq!(*res.report.nodes.last().expect("non-empty"), 1);
    }

    // ── Test 8: wide graph — fan-out then fan-in ──────────────────────────

    #[test]
    fn test_fan_out_fan_in() {
        // root(1) → mid0(2) → sink(1)
        //         → mid1(5) → sink
        //         → mid2(3) → sink
        // Longest: root→mid1→sink = 1+5+1 = 7
        let g = build_graph(
            &[Some(1), Some(2), Some(5), Some(3), Some(1)],
            &[(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)],
        );
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 7);
        assert_eq!(res.report.nodes, vec![0, 2, 4]);
        assert_eq!(res.report.bottleneck, 2); // cost 5
    }

    // ── Test 9: invalid edge returns error ────────────────────────────────

    #[test]
    fn test_invalid_edge() {
        let mut g = InferenceGraph::new();
        g.add_node(NodeLatency::new(10));
        let err = g.add_edge(0, 5).expect_err("node 5 does not exist");
        assert_eq!(err, CriticalPathError::InvalidNode(5));
        // A rejected edge must not be stored.
        assert_eq!(g.num_edges(), 0);
    }

    // ── Test 10: mixed latencies, partially annotated ─────────────────────

    #[test]
    fn test_mixed_latencies() {
        // A(10) → B(None=1 fallback) → C(50)
        // total = 10+1+50 = 61, 1 warning for B
        let g = build_graph(&[Some(10), None, Some(50)], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 61);
        assert_eq!(res.report.bottleneck, 2); // C has cost 50
        assert_eq!(res.warnings.len(), 1);
        assert_eq!(res.warnings[0].node_id, 1);
    }
}