1use std::collections::HashMap;
7use std::fmt::Write;
8
/// A single operation in an execution plan together with its cost estimates.
///
/// Build one with [`PlanStep::new`] and refine it with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct PlanStep {
    /// Position of the step, shown as the step number in reports.
    pub index: usize,
    /// Name of the operation, e.g. `"matmul"`.
    pub operation: String,
    /// Names of this step's inputs.
    pub inputs: Vec<String>,
    /// Name of this step's output.
    pub output: String,
    /// Estimated memory for this step, in bytes.
    pub estimated_memory_bytes: usize,
    /// Estimated floating-point operation count.
    pub estimated_flops: u64,
    /// Whether the step is flagged as parallelizable.
    pub parallelizable: bool,
    /// Dependency level used to group steps in plan statistics and reports.
    pub dependency_level: usize,
}

impl PlanStep {
    /// Creates a step with the given index, operation name and output name.
    ///
    /// Inputs start empty, both estimates start at zero, the parallel flag is `false`
    /// and the dependency level is `0`.
    pub fn new(index: usize, operation: impl Into<String>, output: impl Into<String>) -> Self {
        let operation: String = operation.into();
        let output: String = output.into();
        Self {
            index,
            operation,
            inputs: Vec::new(),
            output,
            estimated_memory_bytes: 0,
            estimated_flops: 0,
            parallelizable: false,
            dependency_level: 0,
        }
    }

    /// Returns the step with its input list replaced by `inputs`.
    pub fn with_inputs(self, inputs: Vec<String>) -> Self {
        Self { inputs, ..self }
    }

    /// Returns the step with its memory estimate set to `bytes`.
    pub fn with_memory(self, bytes: usize) -> Self {
        Self {
            estimated_memory_bytes: bytes,
            ..self
        }
    }

    /// Returns the step with its FLOP estimate set to `flops`.
    pub fn with_flops(self, flops: u64) -> Self {
        Self {
            estimated_flops: flops,
            ..self
        }
    }

    /// Returns the step with its parallel flag set.
    pub fn with_parallel(self, parallelizable: bool) -> Self {
        Self {
            parallelizable,
            ..self
        }
    }

    /// Returns the step with its dependency level set.
    pub fn with_level(self, dependency_level: usize) -> Self {
        Self {
            dependency_level,
            ..self
        }
    }
}
75
76#[derive(Debug, Clone, Default)]
79pub struct ExecutionPlan {
80 pub steps: Vec<PlanStep>,
82}
83
84impl ExecutionPlan {
85 pub fn new() -> Self {
87 Self::default()
88 }
89
90 pub fn add_step(&mut self, step: PlanStep) {
92 self.steps.push(step);
93 }
94
95 pub fn total_flops(&self) -> u64 {
97 self.steps.iter().map(|s| s.estimated_flops).sum()
98 }
99
100 pub fn peak_memory(&self) -> usize {
105 let mut level_mem: HashMap<usize, usize> = HashMap::new();
106 for step in &self.steps {
107 *level_mem.entry(step.dependency_level).or_insert(0) += step.estimated_memory_bytes;
108 }
109 level_mem.values().copied().max().unwrap_or(0)
110 }
111
112 pub fn parallel_count(&self) -> usize {
114 self.steps.iter().filter(|s| s.parallelizable).count()
115 }
116
117 pub fn critical_path_length(&self) -> usize {
121 self.steps
122 .iter()
123 .map(|s| s.dependency_level)
124 .max()
125 .map(|m| m + 1)
126 .unwrap_or(0)
127 }
128
129 pub fn parallel_speedup(&self) -> f64 {
131 let cpl = self.critical_path_length();
132 if cpl == 0 {
133 return 1.0;
134 }
135 self.steps.len() as f64 / cpl as f64
136 }
137}
138
139pub struct PlanFormatter;
141
142impl PlanFormatter {
143 pub fn format_table(plan: &ExecutionPlan) -> String {
146 let mut out = String::new();
147 let _ = writeln!(out, "{:-<80}", "");
148 let _ = writeln!(
149 out,
150 "{:<5} {:<20} {:<20} {:<8} {:<10} {:<5}",
151 "Step", "Operation", "Output", "Level", "Memory", "Par?"
152 );
153 let _ = writeln!(out, "{:-<80}", "");
154 for step in &plan.steps {
155 let mem_str = format_bytes(step.estimated_memory_bytes);
156 let par = if step.parallelizable { "yes" } else { "no" };
157 let _ = writeln!(
158 out,
159 "{:<5} {:<20} {:<20} {:<8} {:<10} {:<5}",
160 step.index,
161 truncate(&step.operation, 19),
162 truncate(&step.output, 19),
163 step.dependency_level,
164 mem_str,
165 par
166 );
167 }
168 let _ = writeln!(out, "{:-<80}", "");
169 let _ = writeln!(
170 out,
171 "Total steps: {} | Critical path: {} | Parallel speedup: {:.1}x",
172 plan.steps.len(),
173 plan.critical_path_length(),
174 plan.parallel_speedup()
175 );
176 let _ = writeln!(
177 out,
178 "Total FLOPs: {} | Peak memory: {}",
179 plan.total_flops(),
180 format_bytes(plan.peak_memory())
181 );
182 out
183 }
184
185 pub fn format_tree(plan: &ExecutionPlan) -> String {
188 let mut out = String::new();
189 let max_level = plan
190 .steps
191 .iter()
192 .map(|s| s.dependency_level)
193 .max()
194 .unwrap_or(0);
195 for level in 0..=max_level {
196 let steps_at_level: Vec<_> = plan
197 .steps
198 .iter()
199 .filter(|s| s.dependency_level == level)
200 .collect();
201 let _ = writeln!(
202 out,
203 "Level {} ({} ops{}):",
204 level,
205 steps_at_level.len(),
206 if steps_at_level.len() > 1 {
207 " \u{2014} parallelizable"
208 } else {
209 ""
210 }
211 );
212 for step in steps_at_level {
213 let _ = writeln!(
214 out,
215 " [{}] {} \u{2192} {}",
216 step.index, step.operation, step.output
217 );
218 }
219 }
220 out
221 }
222}
223
224#[derive(Debug, Clone)]
226pub struct MemoryTimelineEntry {
227 pub step: usize,
229 pub allocated_bytes: usize,
231 pub freed_bytes: usize,
233 pub live_bytes: usize,
235}
236
237pub fn compute_memory_timeline(plan: &ExecutionPlan) -> Vec<MemoryTimelineEntry> {
242 let mut live = 0usize;
243 plan.steps
244 .iter()
245 .map(|step| {
246 live = live.saturating_add(step.estimated_memory_bytes);
247 MemoryTimelineEntry {
248 step: step.index,
249 allocated_bytes: step.estimated_memory_bytes,
250 freed_bytes: 0,
251 live_bytes: live,
252 }
253 })
254 .collect()
255}
256
/// Formats a byte count with a binary unit for display.
///
/// Counts below 1024 are shown as whole bytes (`"512B"`). Counts below 1024² are shown
/// in kibibytes with one decimal (`"1.5KB"`); anything larger is shown in mebibytes
/// with one decimal (`"2.0MB"`).
fn format_bytes(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;
    match bytes {
        b if b < KIB => format!("{}B", b),
        b if b < MIB => format!("{:.1}KB", b as f64 / KIB as f64),
        b => format!("{:.1}MB", b as f64 / MIB as f64),
    }
}
267
/// Shortens `s` to at most `max` characters for fixed-width table cells.
///
/// A string of `max` characters or fewer is returned unchanged. A longer string keeps
/// its first `max - 1` characters followed by `…` (U+2026), so the result is exactly
/// `max` characters long. With `max == 0` the result is empty.
///
/// Lengths are measured in `char`s, not bytes. `{:<N}` padding also counts `char`s, so
/// multi-byte text is neither cut early nor split inside a character.
fn truncate(s: &str, max: usize) -> String {
    if max == 0 {
        return String::new();
    }
    if s.chars().count() <= max {
        return s.to_string();
    }
    // `s` has more than `max` chars, so the (max-1)-th char exists. Its byte offset is
    // where the kept prefix ends, leaving one character of room for the ellipsis.
    let end = s
        .char_indices()
        .nth(max - 1)
        .map_or(s.len(), |(i, _)| i);
    format!("{}\u{2026}", &s[..end])
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-step plan: two level-0 steps (the second flagged parallel) and one level-1 step.
    fn sample_plan() -> ExecutionPlan {
        let mut plan = ExecutionPlan::new();
        plan.add_step(
            PlanStep::new(0, "matmul", "t0")
                .with_inputs(vec!["a".into(), "b".into()])
                .with_memory(1024)
                .with_flops(2000)
                .with_level(0),
        );
        plan.add_step(
            PlanStep::new(1, "relu", "t1")
                .with_inputs(vec!["t0".into()])
                .with_memory(512)
                .with_flops(500)
                .with_parallel(true)
                .with_level(0),
        );
        plan.add_step(
            PlanStep::new(2, "add", "t2")
                .with_inputs(vec!["t0".into(), "t1".into()])
                .with_memory(2048)
                .with_flops(1000)
                .with_level(1),
        );
        plan
    }

    #[test]
    fn test_plan_step_new() {
        let step = PlanStep::new(0, "matmul", "out");
        assert_eq!(step.index, 0);
        assert_eq!(step.operation, "matmul");
        assert_eq!(step.output, "out");
        assert!(step.inputs.is_empty());
        assert_eq!(step.estimated_memory_bytes, 0);
        assert_eq!(step.estimated_flops, 0);
        assert!(!step.parallelizable);
        assert_eq!(step.dependency_level, 0);
    }

    #[test]
    fn test_plan_step_builder() {
        let step = PlanStep::new(1, "conv2d", "feat")
            .with_inputs(vec!["img".into()])
            .with_memory(4096)
            .with_flops(8000)
            .with_parallel(true)
            .with_level(2);
        assert_eq!(step.index, 1);
        assert_eq!(step.inputs, vec!["img".to_string()]);
        assert_eq!(step.estimated_memory_bytes, 4096);
        assert_eq!(step.estimated_flops, 8000);
        assert!(step.parallelizable);
        assert_eq!(step.dependency_level, 2);
    }

    #[test]
    fn test_plan_new_empty() {
        let plan = ExecutionPlan::new();
        assert!(plan.steps.is_empty());
        assert_eq!(plan.total_flops(), 0);
        assert_eq!(plan.peak_memory(), 0);
        assert_eq!(plan.critical_path_length(), 0);
    }

    #[test]
    fn test_plan_add_step() {
        let mut plan = ExecutionPlan::new();
        assert_eq!(plan.steps.len(), 0);
        plan.add_step(PlanStep::new(0, "op", "out"));
        assert_eq!(plan.steps.len(), 1);
        plan.add_step(PlanStep::new(1, "op2", "out2"));
        assert_eq!(plan.steps.len(), 2);
    }

    #[test]
    fn test_plan_total_flops() {
        let plan = sample_plan();
        // 2000 + 500 + 1000
        assert_eq!(plan.total_flops(), 3500);
    }

    #[test]
    fn test_plan_peak_memory() {
        let plan = sample_plan();
        // Level 0: 1024 + 512 = 1536; level 1: 2048. The larger level wins.
        assert_eq!(plan.peak_memory(), 2048);
    }

    #[test]
    fn test_plan_parallel_count() {
        let plan = sample_plan();
        // Only "relu" is flagged parallelizable.
        assert_eq!(plan.parallel_count(), 1);
    }

    #[test]
    fn test_plan_critical_path() {
        let plan = sample_plan();
        // Levels 0 and 1 are in use.
        assert_eq!(plan.critical_path_length(), 2);
    }

    #[test]
    fn test_plan_parallel_speedup() {
        let plan = sample_plan();
        let speedup = plan.parallel_speedup();
        // 3 steps over a critical path of 2.
        assert!((speedup - 1.5).abs() < 1e-9);
    }

    #[test]
    fn test_format_table_header() {
        let plan = sample_plan();
        let table = PlanFormatter::format_table(&plan);
        assert!(table.contains("Step"));
        assert!(table.contains("Operation"));
        assert!(table.contains("Output"));
        assert!(table.contains("Level"));
        assert!(table.contains("Memory"));
        assert!(table.contains("Par?"));
    }

    #[test]
    fn test_format_table_entries() {
        let plan = sample_plan();
        let table = PlanFormatter::format_table(&plan);
        // Compare each row cell by cell, ignoring column padding. Plain substring checks
        // such as `contains("0")` would be satisfied by the summary lines alone.
        let rows: Vec<Vec<&str>> = table
            .lines()
            .map(|line| line.split_whitespace().collect())
            .collect();
        let expected: [&[&str]; 3] = [
            &["0", "matmul", "t0", "0", "1.0KB", "no"],
            &["1", "relu", "t1", "0", "512B", "yes"],
            &["2", "add", "t2", "1", "2.0KB", "no"],
        ];
        for row in expected.iter() {
            assert!(
                rows.iter().any(|r| r.as_slice() == *row),
                "missing row {:?} in table:\n{}",
                row,
                table
            );
        }
    }

    #[test]
    fn test_format_table_summary() {
        let plan = sample_plan();
        let table = PlanFormatter::format_table(&plan);
        assert!(table.contains("Total steps: 3"));
        assert!(table.contains("Critical path: 2"));
        assert!(table.contains("Parallel speedup: 1.5x"));
        assert!(table.contains("Total FLOPs: 3500"));
        assert!(table.contains("Peak memory: 2.0KB"));
    }

    #[test]
    fn test_format_tree_levels() {
        let plan = sample_plan();
        let tree = PlanFormatter::format_tree(&plan);
        assert!(tree.contains("Level 0"));
        assert!(tree.contains("Level 1"));
        assert!(tree.contains("[0] matmul"));
        assert!(tree.contains("[1] relu"));
        assert!(tree.contains("[2] add"));
    }

    #[test]
    fn test_format_tree_parallel_note() {
        let plan = sample_plan();
        let tree = PlanFormatter::format_tree(&plan);
        assert!(tree.contains("parallelizable"));
        let lines: Vec<&str> = tree.lines().collect();
        // Level 0 holds two ops, so it carries the note.
        let level0_line = lines
            .iter()
            .find(|l| l.starts_with("Level 0"))
            .expect("Level 0 line must exist");
        assert!(level0_line.contains("parallelizable"));
        // Level 1 holds a single op, so it must not.
        let level1_line = lines
            .iter()
            .find(|l| l.starts_with("Level 1"))
            .expect("Level 1 line must exist");
        assert!(!level1_line.contains("parallelizable"));
    }

    #[test]
    fn test_memory_timeline_accumulates() {
        let plan = sample_plan();
        let timeline = compute_memory_timeline(&plan);
        // Running sums: 1024, 1024 + 512, 1536 + 2048.
        assert_eq!(timeline[0].live_bytes, 1024);
        assert_eq!(timeline[1].live_bytes, 1536);
        assert_eq!(timeline[2].live_bytes, 3584);
        assert_eq!(timeline[2].allocated_bytes, 2048);
        assert_eq!(
            timeline.iter().map(|e| e.step).collect::<Vec<_>>(),
            vec![0, 1, 2]
        );
        assert!(timeline.iter().all(|e| e.freed_bytes == 0));
    }

    #[test]
    fn test_memory_timeline_length() {
        let plan = sample_plan();
        let timeline = compute_memory_timeline(&plan);
        assert_eq!(timeline.len(), plan.steps.len());
    }

    #[test]
    fn test_format_bytes_b() {
        assert_eq!(format_bytes(512), "512B");
        assert_eq!(format_bytes(0), "0B");
        assert_eq!(format_bytes(1023), "1023B");
    }

    #[test]
    fn test_format_bytes_kb() {
        assert_eq!(format_bytes(2048), "2.0KB");
        assert_eq!(format_bytes(1024), "1.0KB");
    }

    #[test]
    fn test_format_bytes_mb() {
        assert_eq!(format_bytes(1024 * 1024), "1.0MB");
        assert_eq!(format_bytes(3 * 1024 * 1024), "3.0MB");
    }
}