Skip to main content

tensorlogic_infer/
execution_plan.rs

1//! Execution plan formatting and resource timeline visualization.
2//!
3//! Provides human-readable views of execution schedules, parallel opportunity
4//! analysis, and memory usage timelines.
5
6use std::collections::HashMap;
7use std::fmt::Write;
8
/// One scheduled operation inside an execution plan.
#[derive(Clone, Debug)]
pub struct PlanStep {
    /// Position of this step in execution order.
    pub index: usize,
    /// Name or short description of the operation.
    pub operation: String,
    /// Names of the tensors this step consumes.
    pub inputs: Vec<String>,
    /// Name of the tensor this step produces.
    pub output: String,
    /// Estimated memory usage of this step, in bytes.
    pub estimated_memory_bytes: usize,
    /// Estimated floating-point operation count.
    pub estimated_flops: u64,
    /// Whether this step may run in parallel with the previous step.
    pub parallelizable: bool,
    /// Depth in the dependency graph (0 = no dependencies).
    pub dependency_level: usize,
}

impl PlanStep {
    /// Build a step from its index, operation name and output tensor name.
    ///
    /// Every other field starts at its neutral value: no inputs, zero
    /// memory, zero FLOPs, not parallelizable, dependency level 0. Use the
    /// `with_*` methods to fill them in.
    pub fn new(index: usize, operation: impl Into<String>, output: impl Into<String>) -> Self {
        Self {
            index,
            operation: operation.into(),
            output: output.into(),
            inputs: Vec::new(),
            estimated_memory_bytes: 0,
            estimated_flops: 0,
            parallelizable: false,
            dependency_level: 0,
        }
    }

    /// Replace the list of input tensor names.
    pub fn with_inputs(self, inputs: Vec<String>) -> Self {
        Self { inputs, ..self }
    }

    /// Replace the estimated memory usage (bytes).
    pub fn with_memory(self, bytes: usize) -> Self {
        Self {
            estimated_memory_bytes: bytes,
            ..self
        }
    }

    /// Replace the estimated FLOP count.
    pub fn with_flops(self, flops: u64) -> Self {
        Self {
            estimated_flops: flops,
            ..self
        }
    }

    /// Mark the step as parallelizable (or not).
    pub fn with_parallel(self, p: bool) -> Self {
        Self {
            parallelizable: p,
            ..self
        }
    }

    /// Replace the dependency level.
    pub fn with_level(self, l: usize) -> Self {
        Self {
            dependency_level: l,
            ..self
        }
    }
}
75
76/// A complete execution plan containing ordered steps with dependency and
77/// resource metadata.
78#[derive(Debug, Clone, Default)]
79pub struct ExecutionPlan {
80    /// The ordered steps of this execution plan.
81    pub steps: Vec<PlanStep>,
82}
83
84impl ExecutionPlan {
85    /// Create a new empty execution plan.
86    pub fn new() -> Self {
87        Self::default()
88    }
89
90    /// Add a step to the execution plan.
91    pub fn add_step(&mut self, step: PlanStep) {
92        self.steps.push(step);
93    }
94
95    /// Total estimated FLOPs across all steps.
96    pub fn total_flops(&self) -> u64 {
97        self.steps.iter().map(|s| s.estimated_flops).sum()
98    }
99
100    /// Peak estimated memory (sum of all live tensors at the busiest level).
101    ///
102    /// Groups steps by dependency level, sums memory per level, and returns
103    /// the maximum.
104    pub fn peak_memory(&self) -> usize {
105        let mut level_mem: HashMap<usize, usize> = HashMap::new();
106        for step in &self.steps {
107            *level_mem.entry(step.dependency_level).or_insert(0) += step.estimated_memory_bytes;
108        }
109        level_mem.values().copied().max().unwrap_or(0)
110    }
111
112    /// Number of steps that can be parallelized.
113    pub fn parallel_count(&self) -> usize {
114        self.steps.iter().filter(|s| s.parallelizable).count()
115    }
116
117    /// Maximum dependency depth (critical path length).
118    ///
119    /// Returns the number of distinct dependency levels, i.e. max level + 1.
120    pub fn critical_path_length(&self) -> usize {
121        self.steps
122            .iter()
123            .map(|s| s.dependency_level)
124            .max()
125            .map(|m| m + 1)
126            .unwrap_or(0)
127    }
128
129    /// Theoretical speedup from parallelism (total_steps / critical_path_length).
130    pub fn parallel_speedup(&self) -> f64 {
131        let cpl = self.critical_path_length();
132        if cpl == 0 {
133            return 1.0;
134        }
135        self.steps.len() as f64 / cpl as f64
136    }
137}
138
139/// Formatter for rendering execution plans as human-readable strings.
140pub struct PlanFormatter;
141
142impl PlanFormatter {
143    /// Format the plan as a table with columns for step index, operation,
144    /// output, dependency level, memory, and parallelizability.
145    pub fn format_table(plan: &ExecutionPlan) -> String {
146        let mut out = String::new();
147        let _ = writeln!(out, "{:-<80}", "");
148        let _ = writeln!(
149            out,
150            "{:<5} {:<20} {:<20} {:<8} {:<10} {:<5}",
151            "Step", "Operation", "Output", "Level", "Memory", "Par?"
152        );
153        let _ = writeln!(out, "{:-<80}", "");
154        for step in &plan.steps {
155            let mem_str = format_bytes(step.estimated_memory_bytes);
156            let par = if step.parallelizable { "yes" } else { "no" };
157            let _ = writeln!(
158                out,
159                "{:<5} {:<20} {:<20} {:<8} {:<10} {:<5}",
160                step.index,
161                truncate(&step.operation, 19),
162                truncate(&step.output, 19),
163                step.dependency_level,
164                mem_str,
165                par
166            );
167        }
168        let _ = writeln!(out, "{:-<80}", "");
169        let _ = writeln!(
170            out,
171            "Total steps: {} | Critical path: {} | Parallel speedup: {:.1}x",
172            plan.steps.len(),
173            plan.critical_path_length(),
174            plan.parallel_speedup()
175        );
176        let _ = writeln!(
177            out,
178            "Total FLOPs: {} | Peak memory: {}",
179            plan.total_flops(),
180            format_bytes(plan.peak_memory())
181        );
182        out
183    }
184
185    /// Format the plan as a level-grouped tree showing parallelism
186    /// opportunities.
187    pub fn format_tree(plan: &ExecutionPlan) -> String {
188        let mut out = String::new();
189        let max_level = plan
190            .steps
191            .iter()
192            .map(|s| s.dependency_level)
193            .max()
194            .unwrap_or(0);
195        for level in 0..=max_level {
196            let steps_at_level: Vec<_> = plan
197                .steps
198                .iter()
199                .filter(|s| s.dependency_level == level)
200                .collect();
201            let _ = writeln!(
202                out,
203                "Level {} ({} ops{}):",
204                level,
205                steps_at_level.len(),
206                if steps_at_level.len() > 1 {
207                    " \u{2014} parallelizable"
208                } else {
209                    ""
210                }
211            );
212            for step in steps_at_level {
213                let _ = writeln!(
214                    out,
215                    "  [{}] {} \u{2192} {}",
216                    step.index, step.operation, step.output
217                );
218            }
219        }
220        out
221    }
222}
223
/// Memory accounting snapshot for one step of a plan.
#[derive(Clone, Debug)]
pub struct MemoryTimelineEntry {
    /// Index of the step this snapshot belongs to.
    pub step: usize,
    /// Number of bytes allocated by this step.
    pub allocated_bytes: usize,
    /// Number of bytes released by this step.
    pub freed_bytes: usize,
    /// Total bytes still live once this step has run.
    pub live_bytes: usize,
}
236
237/// Compute a memory timeline from an execution plan.
238///
239/// Produces one [`MemoryTimelineEntry`] per step, tracking cumulative
240/// allocations. In this simplified model no memory is freed between steps.
241pub fn compute_memory_timeline(plan: &ExecutionPlan) -> Vec<MemoryTimelineEntry> {
242    let mut live = 0usize;
243    plan.steps
244        .iter()
245        .map(|step| {
246            live = live.saturating_add(step.estimated_memory_bytes);
247            MemoryTimelineEntry {
248                step: step.index,
249                allocated_bytes: step.estimated_memory_bytes,
250                freed_bytes: 0,
251                live_bytes: live,
252            }
253        })
254        .collect()
255}
256
/// Render a byte count in a compact human-readable form.
///
/// Values below 1024 are printed as whole bytes (`512B`), values below
/// 1024 * 1024 as kilobytes with one decimal (`2.0KB`), and everything else
/// as megabytes with one decimal (`1.0MB`).
fn format_bytes(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = 1024 * 1024;
    match bytes {
        b if b < KIB => format!("{}B", b),
        b if b < MIB => format!("{:.1}KB", b as f64 / 1024.0),
        b => format!("{:.1}MB", b as f64 / (1024.0 * 1024.0)),
    }
}
267
/// Truncate a string to at most `max` characters, appending an ellipsis if
/// truncated.
///
/// Length is measured in Unicode scalar values (`char`s), not bytes, so a
/// multi-byte string that already fits is returned unchanged. When the
/// string is too long, the first `max - 1` characters are kept and a single
/// `…` is appended, giving exactly `max` characters. With `max == 0` there
/// is no room even for the ellipsis, so an empty string is returned.
fn truncate(s: &str, max: usize) -> String {
    // Byte length is an upper bound on the character count, so this cheap
    // check settles the common ASCII case without walking the string.
    if s.len() <= max || s.chars().count() <= max {
        return s.to_string();
    }
    if max == 0 {
        return String::new();
    }
    let mut shortened: String = s.chars().take(max - 1).collect();
    shortened.push('\u{2026}');
    shortened
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-step fixture: `matmul` and `relu` on level 0 (only `relu` is
    /// flagged parallelizable) feeding `add` on level 1.
    fn sample_plan() -> ExecutionPlan {
        let mut plan = ExecutionPlan::new();
        plan.add_step(
            PlanStep::new(0, "matmul", "t0")
                .with_inputs(vec!["a".into(), "b".into()])
                .with_memory(1024)
                .with_flops(2000)
                .with_level(0),
        );
        plan.add_step(
            PlanStep::new(1, "relu", "t1")
                .with_inputs(vec!["t0".into()])
                .with_memory(512)
                .with_flops(500)
                .with_parallel(true)
                .with_level(0),
        );
        plan.add_step(
            PlanStep::new(2, "add", "t2")
                .with_inputs(vec!["t0".into(), "t1".into()])
                .with_memory(2048)
                .with_flops(1000)
                .with_level(1),
        );
        plan
    }

    #[test]
    fn test_plan_step_new() {
        let step = PlanStep::new(0, "matmul", "out");
        assert_eq!(step.index, 0);
        assert_eq!(step.operation, "matmul");
        assert_eq!(step.output, "out");
        assert!(step.inputs.is_empty());
        assert_eq!(step.estimated_memory_bytes, 0);
        assert_eq!(step.estimated_flops, 0);
        assert!(!step.parallelizable);
        assert_eq!(step.dependency_level, 0);
    }

    #[test]
    fn test_plan_step_builder() {
        let step = PlanStep::new(1, "conv2d", "feat")
            .with_inputs(vec!["img".into()])
            .with_memory(4096)
            .with_flops(8000)
            .with_parallel(true)
            .with_level(2);
        assert_eq!(step.index, 1);
        assert_eq!(step.inputs, vec!["img".to_string()]);
        assert_eq!(step.estimated_memory_bytes, 4096);
        assert_eq!(step.estimated_flops, 8000);
        assert!(step.parallelizable);
        assert_eq!(step.dependency_level, 2);
    }

    #[test]
    fn test_plan_new_empty() {
        let plan = ExecutionPlan::new();
        assert!(plan.steps.is_empty());
        assert_eq!(plan.total_flops(), 0);
        assert_eq!(plan.peak_memory(), 0);
        assert_eq!(plan.parallel_count(), 0);
        assert_eq!(plan.critical_path_length(), 0);
    }

    #[test]
    fn test_plan_add_step() {
        let mut plan = ExecutionPlan::new();
        assert_eq!(plan.steps.len(), 0);
        plan.add_step(PlanStep::new(0, "op", "out"));
        assert_eq!(plan.steps.len(), 1);
        plan.add_step(PlanStep::new(1, "op2", "out2"));
        assert_eq!(plan.steps.len(), 2);
        // Steps keep insertion order.
        assert_eq!(plan.steps[0].operation, "op");
        assert_eq!(plan.steps[1].operation, "op2");
    }

    #[test]
    fn test_plan_total_flops() {
        let plan = sample_plan();
        // 2000 + 500 + 1000 = 3500
        assert_eq!(plan.total_flops(), 3500);
    }

    #[test]
    fn test_plan_peak_memory() {
        let plan = sample_plan();
        // Level 0: 1024 + 512 = 1536, Level 1: 2048 => peak = 2048
        assert_eq!(plan.peak_memory(), 2048);
    }

    #[test]
    fn test_plan_parallel_count() {
        let plan = sample_plan();
        // Only step 1 is parallelizable
        assert_eq!(plan.parallel_count(), 1);
    }

    #[test]
    fn test_plan_critical_path() {
        let plan = sample_plan();
        // Max level is 1, so critical path = 2
        assert_eq!(plan.critical_path_length(), 2);
    }

    #[test]
    fn test_plan_parallel_speedup() {
        let plan = sample_plan();
        // 3 steps / 2 levels = 1.5
        let speedup = plan.parallel_speedup();
        assert!((speedup - 1.5).abs() < 1e-9);
    }

    #[test]
    fn test_plan_parallel_speedup_empty() {
        // An empty plan has no critical path; the speedup falls back to 1.0.
        let speedup = ExecutionPlan::new().parallel_speedup();
        assert!((speedup - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_format_table_header() {
        let plan = sample_plan();
        let table = PlanFormatter::format_table(&plan);
        assert!(table.contains("Step"));
        assert!(table.contains("Operation"));
        assert!(table.contains("Output"));
        assert!(table.contains("Level"));
        assert!(table.contains("Memory"));
        assert!(table.contains("Par?"));
    }

    #[test]
    fn test_format_table_entries() {
        let plan = sample_plan();
        let table = PlanFormatter::format_table(&plan);
        // Each step must appear as its own row, carrying its index,
        // operation, output and parallel flag together.
        let has_row = |index: &str, op: &str, output: &str, par: &str| {
            table.lines().any(|line| {
                line.starts_with(index)
                    && line.contains(op)
                    && line.contains(output)
                    && line.contains(par)
            })
        };
        assert!(has_row("0 ", "matmul", "t0", "no"));
        assert!(has_row("1 ", "relu", "t1", "yes"));
        assert!(has_row("2 ", "add", "t2", "no"));
    }

    #[test]
    fn test_format_table_truncates_long_names() {
        let mut plan = ExecutionPlan::new();
        plan.add_step(PlanStep::new(0, "a_very_long_operation_name_here", "out"));
        let table = PlanFormatter::format_table(&plan);
        // 19-character budget: 18 characters of the name plus the ellipsis.
        assert!(table.contains("a_very_long_operat\u{2026}"));
        assert!(!table.contains("a_very_long_operati"));
    }

    #[test]
    fn test_format_table_summary() {
        let plan = sample_plan();
        let table = PlanFormatter::format_table(&plan);
        assert!(table.contains("Total steps: 3"));
        assert!(table.contains("Critical path: 2"));
        assert!(table.contains("Parallel speedup: 1.5x"));
        assert!(table.contains("Total FLOPs: 3500"));
        assert!(table.contains("Peak memory: 2.0KB"));
    }

    #[test]
    fn test_format_tree_levels() {
        let plan = sample_plan();
        let tree = PlanFormatter::format_tree(&plan);
        assert!(tree.contains("Level 0"));
        assert!(tree.contains("Level 1"));
        // Step 0 and 1 at level 0
        assert!(tree.contains("[0] matmul"));
        assert!(tree.contains("[1] relu"));
        // Step 2 at level 1
        assert!(tree.contains("[2] add"));
    }

    #[test]
    fn test_format_tree_parallel_note() {
        let plan = sample_plan();
        let tree = PlanFormatter::format_tree(&plan);
        let lines: Vec<&str> = tree.lines().collect();
        // Level 0 has 2 ops, so its header carries the parallelizable note.
        let level0_line = lines
            .iter()
            .find(|l| l.starts_with("Level 0"))
            .expect("Level 0 line must exist");
        assert!(level0_line.contains("parallelizable"));
        // Level 1 has 1 op, so its header must NOT carry the note.
        let level1_line = lines
            .iter()
            .find(|l| l.starts_with("Level 1"))
            .expect("Level 1 line must exist");
        assert!(!level1_line.contains("parallelizable"));
    }

    #[test]
    fn test_memory_timeline_accumulates() {
        let plan = sample_plan();
        let timeline = compute_memory_timeline(&plan);
        // live_bytes is the running total of all allocations so far
        assert_eq!(timeline[0].live_bytes, 1024);
        assert_eq!(timeline[1].live_bytes, 1536);
        assert_eq!(timeline[2].live_bytes, 3584);
    }

    #[test]
    fn test_memory_timeline_allocations() {
        let plan = sample_plan();
        let timeline = compute_memory_timeline(&plan);
        for (entry, step) in timeline.iter().zip(&plan.steps) {
            assert_eq!(entry.step, step.index);
            assert_eq!(entry.allocated_bytes, step.estimated_memory_bytes);
            // The simplified model never frees memory.
            assert_eq!(entry.freed_bytes, 0);
        }
    }

    #[test]
    fn test_memory_timeline_length() {
        let plan = sample_plan();
        let timeline = compute_memory_timeline(&plan);
        assert_eq!(timeline.len(), plan.steps.len());
    }

    #[test]
    fn test_format_bytes_b() {
        assert_eq!(format_bytes(512), "512B");
        assert_eq!(format_bytes(0), "0B");
        assert_eq!(format_bytes(1023), "1023B");
    }

    #[test]
    fn test_format_bytes_kb() {
        assert_eq!(format_bytes(2048), "2.0KB");
        assert_eq!(format_bytes(1024), "1.0KB");
    }

    #[test]
    fn test_format_bytes_mb() {
        assert_eq!(format_bytes(1024 * 1024), "1.0MB");
        assert_eq!(format_bytes(5 * 1024 * 1024 / 2), "2.5MB");
    }

    #[test]
    fn test_truncate_short_unchanged() {
        assert_eq!(truncate("matmul", 19), "matmul");
        assert_eq!(truncate("abcde", 5), "abcde");
        assert_eq!(truncate("", 5), "");
    }

    #[test]
    fn test_truncate_long_ascii() {
        let shortened = truncate("abcdefghij", 5);
        assert_eq!(shortened, "abcd\u{2026}");
        assert_eq!(shortened.chars().count(), 5);
    }
}