1use std::collections::{HashMap, HashSet, VecDeque};
8
9use super::EinsumGraph;
10use crate::error::IrError;
11
/// Optimization goal used when building an [`ExecutionSchedule`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingObjective {
    /// Minimize end-to-end latency: run every node as soon as its
    /// dependencies finish (ASAP / critical-path schedule).
    MinimizeLatency,
    /// Maximize work in flight: pack all ready nodes into each stage.
    MaximizeThroughput,
    /// Minimize memory pressure: serial order that frees input tensors early.
    MinimizeMemory,
    /// Trade latency against memory (currently latency-biased — see
    /// `GraphScheduler::schedule_balanced`).
    Balanced,
    /// Partition the graph into sequential pipeline stages of roughly equal cost.
    Pipeline,
}
26
/// Result of scheduling an `EinsumGraph`: an execution order plus stage
/// structure and cost/memory estimates.
#[derive(Debug, Clone)]
pub struct ExecutionSchedule {
    /// Node indices in the order they should execute.
    pub execution_order: Vec<usize>,
    /// Groups of node indices; one inner `Vec` per stage. Members of a stage
    /// may run concurrently (except under the memory/pipeline objectives,
    /// where stages are serial units).
    pub parallel_stages: Vec<Vec<usize>>,
    /// Cost attributed to each stage, parallel to `parallel_stages`.
    pub stage_costs: Vec<f64>,
    /// Aggregate schedule cost; how it is combined from stage costs depends
    /// on the objective (sum for latency/throughput, max for pipeline).
    pub total_cost: f64,
    /// Peak memory estimate. NOTE(review): never updated by any scheduler in
    /// this file (always 0) — confirm whether this is intended.
    pub peak_memory: usize,
    /// Objective this schedule was built for.
    pub objective: SchedulingObjective,
}
43
44impl ExecutionSchedule {
45 pub fn new(objective: SchedulingObjective) -> Self {
47 Self {
48 execution_order: Vec::new(),
49 parallel_stages: Vec::new(),
50 stage_costs: Vec::new(),
51 total_cost: 0.0,
52 peak_memory: 0,
53 objective,
54 }
55 }
56
57 pub fn num_stages(&self) -> usize {
59 self.parallel_stages.len()
60 }
61
62 pub fn max_parallelism(&self) -> usize {
64 self.parallel_stages
65 .iter()
66 .map(|s| s.len())
67 .max()
68 .unwrap_or(0)
69 }
70
71 pub fn avg_parallelism(&self) -> f64 {
73 if self.parallel_stages.is_empty() {
74 return 0.0;
75 }
76 let total: usize = self.parallel_stages.iter().map(|s| s.len()).sum();
77 total as f64 / self.parallel_stages.len() as f64
78 }
79}
80
/// Builds execution schedules for an `EinsumGraph` under various objectives.
pub struct GraphScheduler {
    /// Estimated execution cost per node index; missing entries default to 1.0.
    operation_costs: HashMap<usize, f64>,
    /// Memory footprint per tensor index; missing entries default to 1 unit.
    tensor_memory: HashMap<usize, usize>,
}
88
89impl GraphScheduler {
90 pub fn new() -> Self {
92 Self {
93 operation_costs: HashMap::new(),
94 tensor_memory: HashMap::new(),
95 }
96 }
97
98 pub fn set_operation_cost(&mut self, node_idx: usize, cost: f64) {
100 self.operation_costs.insert(node_idx, cost);
101 }
102
103 pub fn set_tensor_memory(&mut self, tensor_idx: usize, size: usize) {
105 self.tensor_memory.insert(tensor_idx, size);
106 }
107
108 pub fn schedule(
110 &self,
111 graph: &EinsumGraph,
112 objective: SchedulingObjective,
113 ) -> Result<ExecutionSchedule, IrError> {
114 match objective {
115 SchedulingObjective::MinimizeLatency => self.schedule_min_latency(graph),
116 SchedulingObjective::MaximizeThroughput => self.schedule_max_throughput(graph),
117 SchedulingObjective::MinimizeMemory => self.schedule_min_memory(graph),
118 SchedulingObjective::Balanced => self.schedule_balanced(graph),
119 SchedulingObjective::Pipeline => self.schedule_pipeline(graph),
120 }
121 }
122
123 fn schedule_min_latency(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
125 let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
126
127 let dependencies = self.build_dependencies(graph);
129
130 let start_times = self.compute_start_times(graph, &dependencies);
132
133 let mut stages: HashMap<usize, Vec<usize>> = HashMap::new();
135 for (node_idx, &start_time) in start_times.iter().enumerate() {
136 stages
137 .entry(start_time as usize)
138 .or_default()
139 .push(node_idx);
140 }
141
142 let mut stage_indices: Vec<_> = stages.keys().copied().collect();
144 stage_indices.sort_unstable();
145
146 for stage_idx in stage_indices {
147 if let Some(nodes) = stages.get(&stage_idx) {
148 let stage_cost = nodes
149 .iter()
150 .map(|&idx| self.get_operation_cost(idx))
151 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
152 .unwrap_or(0.0);
153
154 schedule.parallel_stages.push(nodes.clone());
155 schedule.stage_costs.push(stage_cost);
156 schedule.total_cost += stage_cost;
157
158 for &node in nodes {
159 schedule.execution_order.push(node);
160 }
161 }
162 }
163
164 Ok(schedule)
165 }
166
167 fn schedule_max_throughput(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
169 let mut schedule = ExecutionSchedule::new(SchedulingObjective::MaximizeThroughput);
170
171 let dependencies = self.build_dependencies(graph);
173 #[allow(clippy::unnecessary_map_or)]
174 let mut ready: Vec<usize> = (0..graph.nodes.len())
175 .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
176 .collect();
177
178 ready.sort_by(|&a, &b| {
180 let cost_a = self.get_operation_cost(a);
181 let cost_b = self.get_operation_cost(b);
182 cost_b
183 .partial_cmp(&cost_a)
184 .unwrap_or(std::cmp::Ordering::Equal)
185 });
186
187 let mut scheduled = HashSet::new();
188 let _in_degree = self.compute_in_degrees(graph, &dependencies);
189
190 while !ready.is_empty() {
191 let mut stage = Vec::new();
192 let mut stage_cost: f64 = 0.0;
193
194 for &node_idx in &ready {
196 let cost = self.get_operation_cost(node_idx);
197 stage.push(node_idx);
198 stage_cost = stage_cost.max(cost);
199 scheduled.insert(node_idx);
200 schedule.execution_order.push(node_idx);
201 }
202
203 schedule.parallel_stages.push(stage);
204 schedule.stage_costs.push(stage_cost);
205 schedule.total_cost += stage_cost;
206
207 ready.clear();
209 for (node_idx, deps) in &dependencies {
210 if scheduled.contains(node_idx) {
211 continue;
212 }
213
214 let all_deps_scheduled = deps.iter().all(|&dep| scheduled.contains(&dep));
215 if all_deps_scheduled {
216 ready.push(*node_idx);
217 }
218 }
219
220 ready.sort_by(|&a, &b| {
222 let cost_a = self.get_operation_cost(a);
223 let cost_b = self.get_operation_cost(b);
224 cost_b
225 .partial_cmp(&cost_a)
226 .unwrap_or(std::cmp::Ordering::Equal)
227 });
228 }
229
230 Ok(schedule)
231 }
232
233 fn schedule_min_memory(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
235 let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeMemory);
236
237 let dependencies = self.build_dependencies(graph);
239 let tensor_lifetimes = self.compute_tensor_lifetimes(graph);
240
241 #[allow(clippy::unnecessary_map_or)]
242 let mut ready: Vec<usize> = (0..graph.nodes.len())
243 .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
244 .collect();
245
246 let mut scheduled = HashSet::new();
247
248 while !ready.is_empty() {
249 let best_idx = ready
251 .iter()
252 .max_by_key(|&&idx| self.estimate_memory_freed(graph, idx, &tensor_lifetimes))
253 .copied()
254 .expect("ready list is non-empty at this point in the loop");
255
256 ready.retain(|&idx| idx != best_idx);
257
258 schedule.execution_order.push(best_idx);
259 schedule.parallel_stages.push(vec![best_idx]);
260 let cost = self.get_operation_cost(best_idx);
261 schedule.stage_costs.push(cost);
262 schedule.total_cost += cost;
263 scheduled.insert(best_idx);
264
265 for (node_idx, deps) in &dependencies {
267 if scheduled.contains(node_idx) || ready.contains(node_idx) {
268 continue;
269 }
270
271 if deps.iter().all(|&dep| scheduled.contains(&dep)) {
272 ready.push(*node_idx);
273 }
274 }
275 }
276
277 Ok(schedule)
278 }
279
280 fn schedule_balanced(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
282 let latency_schedule = self.schedule_min_latency(graph)?;
284 let _memory_schedule = self.schedule_min_memory(graph)?;
285
286 Ok(latency_schedule)
289 }
290
291 fn schedule_pipeline(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
293 let mut schedule = ExecutionSchedule::new(SchedulingObjective::Pipeline);
294
295 let stages = self.partition_for_pipeline(graph)?;
297
298 for stage_nodes in stages {
299 let stage_cost = stage_nodes
300 .iter()
301 .map(|&idx| self.get_operation_cost(idx))
302 .sum();
303
304 schedule.parallel_stages.push(stage_nodes.clone());
305 schedule.stage_costs.push(stage_cost);
306 schedule.total_cost = schedule.total_cost.max(stage_cost);
307
308 for &node in &stage_nodes {
309 schedule.execution_order.push(node);
310 }
311 }
312
313 Ok(schedule)
314 }
315
316 fn build_dependencies(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
318 let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
319 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
320
321 for (node_idx, node) in graph.nodes.iter().enumerate() {
323 for &output_idx in &node.outputs {
324 tensor_producer.insert(output_idx, node_idx);
325 }
326 }
327
328 for (node_idx, node) in graph.nodes.iter().enumerate() {
330 let mut deps = Vec::new();
331 for &input_idx in &node.inputs {
332 if let Some(&producer) = tensor_producer.get(&input_idx) {
333 if producer != node_idx {
334 deps.push(producer);
335 }
336 }
337 }
338 dependencies.insert(node_idx, deps);
339 }
340
341 dependencies
342 }
343
344 fn compute_start_times(
346 &self,
347 graph: &EinsumGraph,
348 dependencies: &HashMap<usize, Vec<usize>>,
349 ) -> Vec<f64> {
350 let mut start_times = vec![0.0; graph.nodes.len()];
351 let mut visited = HashSet::new();
352 let mut queue = VecDeque::new();
353
354 for (node_idx, deps) in dependencies {
356 if deps.is_empty() {
357 queue.push_back(*node_idx);
358 }
359 }
360
361 while let Some(node_idx) = queue.pop_front() {
362 if visited.contains(&node_idx) {
363 continue;
364 }
365 visited.insert(node_idx);
366
367 let deps = dependencies
369 .get(&node_idx)
370 .map(|v| v.as_slice())
371 .unwrap_or(&[]);
372 let max_dep_finish = deps
373 .iter()
374 .map(|&dep_idx| start_times[dep_idx] + self.get_operation_cost(dep_idx))
375 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
376 .unwrap_or(0.0);
377
378 start_times[node_idx] = max_dep_finish;
379
380 for (succ_idx, succ_deps) in dependencies {
382 if succ_deps.contains(&node_idx) && !visited.contains(succ_idx) {
383 queue.push_back(*succ_idx);
384 }
385 }
386 }
387
388 start_times
389 }
390
391 fn compute_in_degrees(
393 &self,
394 graph: &EinsumGraph,
395 dependencies: &HashMap<usize, Vec<usize>>,
396 ) -> Vec<usize> {
397 let mut in_degree = vec![0; graph.nodes.len()];
398 for (node_idx, deps) in dependencies {
399 in_degree[*node_idx] = deps.len();
400 }
401 in_degree
402 }
403
404 fn compute_tensor_lifetimes(&self, graph: &EinsumGraph) -> HashMap<usize, (usize, usize)> {
406 let mut lifetimes = HashMap::new();
407
408 for (node_idx, node) in graph.nodes.iter().enumerate() {
409 for &tensor_idx in &node.inputs {
410 let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
411 entry.0 = entry.0.min(node_idx);
412 entry.1 = entry.1.max(node_idx);
413 }
414 for &tensor_idx in &node.outputs {
415 let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
416 entry.0 = entry.0.min(node_idx);
417 entry.1 = entry.1.max(node_idx);
418 }
419 }
420
421 lifetimes
422 }
423
424 fn estimate_memory_freed(
426 &self,
427 graph: &EinsumGraph,
428 node_idx: usize,
429 lifetimes: &HashMap<usize, (usize, usize)>,
430 ) -> usize {
431 let node = &graph.nodes[node_idx];
432 let mut freed = 0;
433
434 for &input_tensor in &node.inputs {
435 if let Some(&(_, last_use)) = lifetimes.get(&input_tensor) {
436 if last_use == node_idx {
437 freed += self.tensor_memory.get(&input_tensor).copied().unwrap_or(1);
438 }
439 }
440 }
441
442 freed
443 }
444
445 fn partition_for_pipeline(&self, graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
447 let total_cost: f64 = (0..graph.nodes.len())
449 .map(|i| self.get_operation_cost(i))
450 .sum();
451
452 let target_stages = 4; let target_cost_per_stage = total_cost / target_stages as f64;
454
455 let dependencies = self.build_dependencies(graph);
456 let topo_order = self.topological_sort(graph, &dependencies);
457
458 let mut stages = Vec::new();
459 let mut current_stage = Vec::new();
460 let mut current_cost = 0.0;
461
462 for &node_idx in &topo_order {
463 let cost = self.get_operation_cost(node_idx);
464 current_stage.push(node_idx);
465 current_cost += cost;
466
467 if current_cost >= target_cost_per_stage {
468 stages.push(current_stage.clone());
469 current_stage.clear();
470 current_cost = 0.0;
471 }
472 }
473
474 if !current_stage.is_empty() {
475 stages.push(current_stage);
476 }
477
478 Ok(stages)
479 }
480
481 fn topological_sort(
483 &self,
484 graph: &EinsumGraph,
485 dependencies: &HashMap<usize, Vec<usize>>,
486 ) -> Vec<usize> {
487 let mut result = Vec::new();
488 let mut visited = HashSet::new();
489 let mut in_degree = self.compute_in_degrees(graph, dependencies);
490
491 let mut queue: VecDeque<usize> = (0..graph.nodes.len())
492 .filter(|&i| in_degree[i] == 0)
493 .collect();
494
495 while let Some(node_idx) = queue.pop_front() {
496 if visited.contains(&node_idx) {
497 continue;
498 }
499 visited.insert(node_idx);
500 result.push(node_idx);
501
502 for (succ_idx, deps) in dependencies {
504 if deps.contains(&node_idx) {
505 in_degree[*succ_idx] = in_degree[*succ_idx].saturating_sub(1);
506 if in_degree[*succ_idx] == 0 {
507 queue.push_back(*succ_idx);
508 }
509 }
510 }
511 }
512
513 result
514 }
515
516 fn get_operation_cost(&self, node_idx: usize) -> f64 {
518 self.operation_costs.get(&node_idx).copied().unwrap_or(1.0)
519 }
520}
521
522impl Default for GraphScheduler {
523 fn default() -> Self {
524 Self::new()
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    // A fresh schedule carries its objective and has no stages.
    #[test]
    fn test_execution_schedule_creation() {
        let schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.objective, SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.num_stages(), 0);
    }

    // Parallelism stats: max is the widest stage, avg is nodes / stages.
    #[test]
    fn test_execution_schedule_stats() {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        schedule.parallel_stages.push(vec![0, 1, 2]);
        schedule.parallel_stages.push(vec![3]);

        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.max_parallelism(), 3);
        assert_eq!(schedule.avg_parallelism(), 2.0);
    }

    // A new scheduler starts with no cost annotations.
    #[test]
    fn test_scheduler_creation() {
        let scheduler = GraphScheduler::new();
        assert!(scheduler.operation_costs.is_empty());
    }

    // Cost and memory annotations round-trip through the setters.
    #[test]
    fn test_scheduler_set_costs() {
        let mut scheduler = GraphScheduler::new();
        scheduler.set_operation_cost(0, 5.0);
        scheduler.set_tensor_memory(1, 1024);

        assert_eq!(scheduler.get_operation_cost(0), 5.0);
        assert_eq!(scheduler.tensor_memory.get(&1), Some(&1024));
    }

    // Scheduling an empty graph yields an empty (zero-stage) schedule.
    #[test]
    fn test_schedule_empty_graph() {
        let scheduler = GraphScheduler::new();
        let graph = EinsumGraph::new();

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .expect("unwrap");
        assert_eq!(schedule.num_stages(), 0);
    }

    // A single node produces a one-entry order whose total cost is its own cost.
    #[test]
    fn test_schedule_single_node() {
        let mut scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");

        scheduler.set_operation_cost(0, 2.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .expect("unwrap");
        assert_eq!(schedule.execution_order.len(), 1);
        assert_eq!(schedule.total_cost, 2.0);
    }

    // Node 1 consumes node 0's output, so it must depend on node 0.
    #[test]
    fn test_build_dependencies() {
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .expect("unwrap");

        let deps = scheduler.build_dependencies(&graph);
        assert_eq!(deps.get(&0).expect("unwrap").len(), 0);
        assert_eq!(deps.get(&1).expect("unwrap"), &vec![0]);
    }

    // Topological order of a two-node chain is producer before consumer.
    #[test]
    fn test_topological_sort() {
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .expect("unwrap");

        let deps = scheduler.build_dependencies(&graph);
        let topo = scheduler.topological_sort(&graph, &deps);

        assert_eq!(topo.len(), 2);
        assert_eq!(topo[0], 0);
        assert_eq!(topo[1], 1);
    }

    // Objective variants support equality comparison (derive PartialEq/Eq).
    #[test]
    fn test_scheduling_objectives() {
        assert_eq!(
            SchedulingObjective::MinimizeLatency,
            SchedulingObjective::MinimizeLatency
        );
        assert_ne!(
            SchedulingObjective::MinimizeLatency,
            SchedulingObjective::MaximizeThroughput
        );
    }
}