tensorlogic_infer/critical_path.rs
1//! Critical path analysis for inference computation graphs.
2//!
3//! Computes the longest dependency chain in a DAG (the critical path) using:
4//! 1. Kahn's algorithm for topological sort.
5//! 2. DP over the topological order: `dist[v] = max_over_predecessors(dist[u] + cost(v))`.
6//!
//! The result includes the path as a sequence of [`NodeId`]s, the total
//! accumulated latency in nanoseconds, and the single bottleneck node
//! (the node with the highest individual latency on that path).
10//!
11//! Nodes whose `latency_ns` is `None` are treated as 1 ns and a
12//! [`MissingCostWarning`] is emitted for each one.
13//!
14//! Cycle detection: if Kahn's algorithm cannot drain all nodes a
15//! [`CriticalPathError::CycleDetected`] is returned instead of panicking.
16//!
17//! # Example
18//!
19//! ```rust
20//! use tensorlogic_infer::critical_path::{
21//! InferenceGraph, NodeLatency, critical_path,
22//! };
23//!
24//! let mut g = InferenceGraph::default();
25//! let a = g.add_node(NodeLatency { latency_ns: Some(10) });
26//! let b = g.add_node(NodeLatency { latency_ns: Some(20) });
27//! let c = g.add_node(NodeLatency { latency_ns: Some(5) });
28//! g.add_edge(a, b).unwrap();
29//! g.add_edge(b, c).unwrap();
30//!
31//! let result = critical_path(&g).unwrap();
32//! assert_eq!(result.report.nodes, vec![a, b, c]);
33//! assert_eq!(result.report.total_latency_ns, 35);
34//! ```
35
36use std::collections::VecDeque;
37use thiserror::Error;
38
39// ─────────────────────────────────────────────────────────────────────────────
40// Public types
41// ─────────────────────────────────────────────────────────────────────────────
42
/// Unique identifier for a node in an [`InferenceGraph`].
///
/// An id is the node's index into [`InferenceGraph::nodes`]; ids are handed
/// out by [`InferenceGraph::add_node`] in insertion order, starting at zero.
pub type NodeId = usize;
45
/// Latency annotation attached to a single graph node.
///
/// A `latency_ns` of `None` means "not annotated": the analysis then assumes
/// 1 ns for the node and reports it through a [`MissingCostWarning`].
/// The `Default` value is the un-annotated state.
#[derive(Debug, Clone, Default)]
pub struct NodeLatency {
    /// Estimated execution latency of the node, in nanoseconds.
    pub latency_ns: Option<u64>,
}

impl NodeLatency {
    /// Build an annotation carrying a known latency of `latency_ns` nanoseconds.
    pub fn new(latency_ns: u64) -> Self {
        let latency_ns = Some(latency_ns);
        NodeLatency { latency_ns }
    }
}
64
65/// A lightweight Directed Acyclic Graph (DAG) of inference nodes.
66///
67/// Nodes are added in order and receive consecutive [`NodeId`]s starting from
68/// zero. Edges encode data-flow dependencies: edge `(from, to)` means "node
69/// `from` must execute before node `to`".
70#[derive(Debug, Clone, Default)]
71pub struct InferenceGraph {
72 /// Per-node latency annotations; index == [`NodeId`].
73 pub nodes: Vec<NodeLatency>,
74 /// Directed edges `(from, to)` — i.e. `from` → `to`.
75 pub edges: Vec<(NodeId, NodeId)>,
76}
77
78impl InferenceGraph {
79 /// Create an empty graph.
80 pub fn new() -> Self {
81 Self::default()
82 }
83
84 /// Add a node with the given latency annotation and return its [`NodeId`].
85 pub fn add_node(&mut self, latency: NodeLatency) -> NodeId {
86 let id = self.nodes.len();
87 self.nodes.push(latency);
88 id
89 }
90
91 /// Add a directed edge `from → to`.
92 ///
93 /// Returns [`CriticalPathError::InvalidNode`] if either node index is out
94 /// of range.
95 pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<(), CriticalPathError> {
96 let n = self.nodes.len();
97 if from >= n {
98 return Err(CriticalPathError::InvalidNode(from));
99 }
100 if to >= n {
101 return Err(CriticalPathError::InvalidNode(to));
102 }
103 self.edges.push((from, to));
104 Ok(())
105 }
106
107 /// Number of nodes in the graph.
108 pub fn num_nodes(&self) -> usize {
109 self.nodes.len()
110 }
111
112 /// Number of edges in the graph.
113 pub fn num_edges(&self) -> usize {
114 self.edges.len()
115 }
116}
117
118// ─────────────────────────────────────────────────────────────────────────────
119// Result types
120// ─────────────────────────────────────────────────────────────────────────────
121
/// The critical-path analysis report.
///
/// Produced by [`critical_path`]. For an empty graph every field takes its
/// "empty" value: no nodes, zero latency and a `bottleneck` of `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CriticalPathReport {
    /// Ordered sequence of [`NodeId`]s on the critical path, from source to
    /// sink (inclusive). Empty when the graph has no nodes.
    pub nodes: Vec<NodeId>,
    /// Sum of node latencies along the critical path, in nanoseconds.
    ///
    /// Nodes without an annotation contribute 1 ns each. The sum is
    /// accumulated with saturating arithmetic, so it caps at `u64::MAX`
    /// instead of overflowing.
    pub total_latency_ns: u64,
    /// The node with the highest individual latency on the critical path.
    /// Zero when `nodes` is empty.
    ///
    /// When several nodes on the path share the highest latency, the one
    /// that appears latest on the path is reported. Because `0` is also a
    /// valid node id, check `nodes.is_empty()` to tell the empty-graph case
    /// apart from "node 0 is the bottleneck".
    pub bottleneck: NodeId,
}
134
/// Warning emitted when a node has no latency annotation.
///
/// [`critical_path`] produces one warning for every node in the graph whose
/// `latency_ns` is `None` (whether or not the node ends up on the critical
/// path), in ascending node-id order. Such nodes are costed at 1 ns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingCostWarning {
    /// The node whose `latency_ns` was `None`.
    pub node_id: NodeId,
}
141
142/// Combined result: analysis report plus any latency-annotation warnings.
143#[derive(Debug, Clone)]
144pub struct CriticalPathResult {
145 pub report: CriticalPathReport,
146 pub warnings: Vec<MissingCostWarning>,
147}
148
/// Errors returned by [`critical_path`] and [`InferenceGraph::add_edge`].
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum CriticalPathError {
    /// The graph contains a directed cycle; critical-path analysis requires a
    /// DAG. The enclosed string names the nodes that could not be processed.
    ///
    /// The string is a `", "`-separated list of node ids in ascending order.
    /// It lists every node the topological sort could not reach, which
    /// includes nodes that merely depend on a cycle as well as the nodes
    /// forming it.
    #[error("Cycle detected; these nodes were not reachable via topological sort: {0}")]
    CycleDetected(String),

    /// An edge references a node index that does not exist in the graph.
    /// The payload is the offending index.
    #[error("Edge references out-of-range node id {0}")]
    InvalidNode(NodeId),
}
161
162// ─────────────────────────────────────────────────────────────────────────────
163// Core algorithm
164// ─────────────────────────────────────────────────────────────────────────────
165
166/// Compute the critical path of `graph`.
167///
168/// Returns `Ok(CriticalPathResult)` for valid DAGs, or
169/// `Err(CriticalPathError::CycleDetected)` if the graph is cyclic.
170///
171/// # Algorithm
172///
173/// 1. Build forward adjacency list and reverse adjacency list (predecessor map)
174/// together with in-degree counts.
175/// 2. Run Kahn's BFS topological sort. If any nodes are left unprocessed,
176/// a cycle exists.
177/// 3. DP over topo order: `dist[v] = cost(v) + max(dist[u] for u in pred(v))`.
178/// Track the predecessor that achieved the maximum for path reconstruction.
179/// 4. The node with the maximum `dist` value is the end of the critical path.
180/// 5. Walk predecessors back to reconstruct the path, then reverse it.
181pub fn critical_path(graph: &InferenceGraph) -> Result<CriticalPathResult, CriticalPathError> {
182 let n = graph.num_nodes();
183
184 // Empty graph — trivial result.
185 if n == 0 {
186 return Ok(CriticalPathResult {
187 report: CriticalPathReport {
188 nodes: vec![],
189 total_latency_ns: 0,
190 bottleneck: 0,
191 },
192 warnings: vec![],
193 });
194 }
195
196 // ── 1. Build adjacency structures ──────────────────────────────────────
197 //
198 // `succ[u]` = list of nodes that depend on u (successors).
199 // `pred[v]` = list of nodes that v depends on (predecessors).
200 // `in_degree[v]` = number of predecessors.
201
202 let mut succ: Vec<Vec<NodeId>> = vec![vec![]; n];
203 let mut pred: Vec<Vec<NodeId>> = vec![vec![]; n];
204 let mut in_degree: Vec<usize> = vec![0; n];
205
206 for &(from, to) in &graph.edges {
207 // Edge validity was checked at add_edge time, but edges could have
208 // been added to the raw field directly; guard anyway.
209 if from >= n || to >= n {
210 return Err(CriticalPathError::InvalidNode(if from >= n {
211 from
212 } else {
213 to
214 }));
215 }
216 succ[from].push(to);
217 pred[to].push(from);
218 in_degree[to] += 1;
219 }
220
221 // ── 2. Collect per-node costs; emit warnings for missing annotations ───
222
223 let mut warnings: Vec<MissingCostWarning> = vec![];
224 let costs: Vec<u64> = graph
225 .nodes
226 .iter()
227 .enumerate()
228 .map(|(id, nl)| {
229 nl.latency_ns.unwrap_or_else(|| {
230 warnings.push(MissingCostWarning { node_id: id });
231 1
232 })
233 })
234 .collect();
235
236 // ── 3. Kahn's BFS topological sort ─────────────────────────────────────
237
238 let mut queue: VecDeque<NodeId> = VecDeque::new();
239 for v in 0..n {
240 if in_degree[v] == 0 {
241 queue.push_back(v);
242 }
243 }
244
245 let mut topo_order: Vec<NodeId> = Vec::with_capacity(n);
246 // Work on a mutable copy of in-degrees so we can decrement during BFS.
247 let mut remaining_in: Vec<usize> = in_degree.clone();
248
249 while let Some(u) = queue.pop_front() {
250 topo_order.push(u);
251 for &v in &succ[u] {
252 remaining_in[v] -= 1;
253 if remaining_in[v] == 0 {
254 queue.push_back(v);
255 }
256 }
257 }
258
259 if topo_order.len() != n {
260 // Some nodes were not processed — there is a cycle.
261 let cyclic: Vec<String> = (0..n)
262 .filter(|&v| !topo_order.contains(&v))
263 .map(|v| v.to_string())
264 .collect();
265 return Err(CriticalPathError::CycleDetected(cyclic.join(", ")));
266 }
267
268 // ── 4. DP: longest path in topo order ──────────────────────────────────
269 //
270 // `dist[v]` = maximum accumulated latency of any path ending at `v`
271 // (including v's own cost).
272 // `best_pred[v]` = the predecessor that achieved `dist[v]`, or `None` for
273 // source nodes.
274
275 let mut dist: Vec<u64> = vec![0; n];
276 let mut best_pred: Vec<Option<NodeId>> = vec![None; n];
277
278 for &v in &topo_order {
279 // Start with just this node's own cost.
280 dist[v] = costs[v];
281 best_pred[v] = None;
282
283 // Extend the longest predecessor path.
284 for &u in &pred[v] {
285 let candidate = dist[u].saturating_add(costs[v]);
286 if candidate > dist[v] {
287 dist[v] = candidate;
288 best_pred[v] = Some(u);
289 }
290 }
291 }
292
293 // ── 5. Find the sink with the maximum distance ─────────────────────────
294
295 let (end_node, &max_dist) = dist
296 .iter()
297 .enumerate()
298 .max_by_key(|&(_, d)| d)
299 .unwrap_or((0, &0)); // Safety: n > 0 so the iterator is non-empty.
300
301 // ── 6. Reconstruct the path by walking back through best_pred ──────────
302
303 let mut path: Vec<NodeId> = vec![];
304 let mut current = end_node;
305 loop {
306 path.push(current);
307 match best_pred[current] {
308 Some(prev) => current = prev,
309 None => break,
310 }
311 }
312 path.reverse();
313
314 // ── 7. Identify bottleneck: node on path with highest individual cost ──
315
316 let bottleneck = path
317 .iter()
318 .copied()
319 .max_by_key(|&v| costs[v])
320 .unwrap_or(end_node);
321
322 Ok(CriticalPathResult {
323 report: CriticalPathReport {
324 nodes: path,
325 total_latency_ns: max_dist,
326 bottleneck,
327 },
328 warnings,
329 })
330}
331
332// ─────────────────────────────────────────────────────────────────────────────
333// Tests
334// ─────────────────────────────────────────────────────────────────────────────
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a graph from a node-latency list and an edge list.
    ///
    /// Node ids are the positions in `latencies`. Panics if an edge refers to
    /// a node that does not exist.
    fn build_graph(latencies: &[Option<u64>], edges: &[(usize, usize)]) -> InferenceGraph {
        let mut g = InferenceGraph::new();
        for &lat in latencies {
            g.add_node(NodeLatency { latency_ns: lat });
        }
        for &(from, to) in edges {
            g.add_edge(from, to).expect("valid edge");
        }
        g
    }

    // ── Test 1: linear chain A→B→C ────────────────────────────────────────

    #[test]
    fn test_linear_chain() {
        // A(10) → B(20) → C(5) — only one path, total = 35
        let g = build_graph(&[Some(10), Some(20), Some(5)], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.nodes, vec![0, 1, 2]);
        assert_eq!(res.report.total_latency_ns, 35);
        assert_eq!(res.report.bottleneck, 1); // node 1 has cost 20
        assert!(res.warnings.is_empty());
    }

    // ── Test 2: diamond — longer branch wins ─────────────────────────────

    #[test]
    fn test_diamond_longer_branch_wins() {
        // A(1) → B(100) → D(1)
        //      → C(1)   → D
        // Critical path: A→B→D, total = 102
        let g = build_graph(
            &[Some(1), Some(100), Some(1), Some(1)],
            &[(0, 1), (0, 2), (1, 3), (2, 3)],
        );
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.nodes, vec![0, 1, 3]);
        assert_eq!(res.report.total_latency_ns, 102);
        assert_eq!(res.report.bottleneck, 1); // node 1 has cost 100
        assert!(res.warnings.is_empty());
    }

    // ── Test 3: single node ───────────────────────────────────────────────

    #[test]
    fn test_single_node() {
        let g = build_graph(&[Some(42)], &[]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.nodes, vec![0]);
        assert_eq!(res.report.total_latency_ns, 42);
        assert_eq!(res.report.bottleneck, 0);
        assert!(res.warnings.is_empty());
    }

    // ── Test 4: missing latency annotations emit warnings ─────────────────

    #[test]
    fn test_missing_latency_warning() {
        // A(None) → B(None) → C(None)
        // Each falls back to 1 ns → total = 3, warnings for all three.
        let g = build_graph(&[None, None, None], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 3);
        assert_eq!(res.warnings.len(), 3);
        let warned_ids: Vec<NodeId> = res.warnings.iter().map(|w| w.node_id).collect();
        assert!(warned_ids.contains(&0));
        assert!(warned_ids.contains(&1));
        assert!(warned_ids.contains(&2));
    }

    // ── Test 5: empty graph ───────────────────────────────────────────────

    #[test]
    fn test_empty_graph() {
        let g = InferenceGraph::new();
        let res = critical_path(&g).expect("no cycle");

        assert!(res.report.nodes.is_empty());
        assert_eq!(res.report.total_latency_ns, 0);
        assert_eq!(res.report.bottleneck, 0);
        assert!(res.warnings.is_empty());
    }

    // ── Test 6: cycle detection returns an error ──────────────────────────

    #[test]
    fn test_cycle_detected() {
        // A → B → C → A (cycle): no node has in-degree zero, so all three
        // are reported as unprocessed.
        let g = build_graph(&[Some(1), Some(1), Some(1)], &[(0, 1), (1, 2), (2, 0)]);
        let err = critical_path(&g).expect_err("should detect cycle");
        // A bare `matches!` only yields a bool; it has to be asserted (here
        // via `assert_eq!`) for the test to be able to fail.
        assert_eq!(err, CriticalPathError::CycleDetected("0, 1, 2".to_string()));
    }

    // ── Test 7: parallel branches without shared sink ─────────────────────

    #[test]
    fn test_parallel_branches() {
        // Two independent chains: A(5)→B(10) and C(1)→D(3)
        // Longest path ends at B with dist 15.
        let g = build_graph(&[Some(5), Some(10), Some(1), Some(3)], &[(0, 1), (2, 3)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 15);
        assert_eq!(*res.report.nodes.last().expect("non-empty"), 1);
    }

    // ── Test 8: wide graph — fan-out then fan-in ──────────────────────────

    #[test]
    fn test_fan_out_fan_in() {
        // root(1) → mid0(2) → sink(1)
        //         → mid1(5) → sink
        //         → mid2(3) → sink
        // Longest: root→mid1→sink = 1+5+1 = 7
        let g = build_graph(
            &[Some(1), Some(2), Some(5), Some(3), Some(1)],
            &[(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)],
        );
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 7);
        assert_eq!(res.report.nodes, vec![0, 2, 4]);
        assert_eq!(res.report.bottleneck, 2); // cost 5
    }

    // ── Test 9: invalid edge returns error ────────────────────────────────

    #[test]
    fn test_invalid_edge() {
        let mut g = InferenceGraph::new();
        g.add_node(NodeLatency::new(10));
        let err = g.add_edge(0, 5).expect_err("node 5 does not exist");
        assert_eq!(err, CriticalPathError::InvalidNode(5));
        // The rejected edge must not have been recorded.
        assert_eq!(g.num_edges(), 0);
    }

    // ── Test 9b: invalid edge written to the raw field is caught later ────

    #[test]
    fn test_invalid_raw_edge_rejected_by_analysis() {
        let mut g = InferenceGraph::new();
        g.add_node(NodeLatency::new(10));
        // Bypass `add_edge` validation through the public field.
        g.edges.push((0, 9));
        let err = critical_path(&g).expect_err("node 9 does not exist");
        assert_eq!(err, CriticalPathError::InvalidNode(9));
    }

    // ── Test 10: mixed latencies, partially annotated ─────────────────────

    #[test]
    fn test_mixed_latencies() {
        // A(10) → B(None=1 fallback) → C(50)
        // total = 10+1+50 = 61, 1 warning for B
        let g = build_graph(&[Some(10), None, Some(50)], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");

        assert_eq!(res.report.total_latency_ns, 61);
        assert_eq!(res.report.bottleneck, 2); // C has cost 50
        assert_eq!(res.warnings.len(), 1);
        assert_eq!(res.warnings[0].node_id, 1);
    }
}
494}