1use std::collections::{HashMap, HashSet, VecDeque};
8
9use super::EinsumGraph;
10use crate::error::IrError;
11
/// Optimization goal passed to [`GraphScheduler::schedule`] to pick a scheduling strategy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SchedulingObjective {
    /// Start every node as soon as all of its producers have finished.
    MinimizeLatency,
    /// Group all simultaneously ready nodes into one stage, round after round.
    MaximizeThroughput,
    /// Run nodes one at a time, greedily preferring nodes that release tensor memory.
    MinimizeMemory,
    /// Currently produces the same stage layout as `MinimizeLatency`.
    Balanced,
    /// Split a topological order into consecutive, cost-balanced pipeline stages.
    Pipeline,
}
26
27#[derive(Debug, Clone)]
29pub struct ExecutionSchedule {
30 pub execution_order: Vec<usize>,
32 pub parallel_stages: Vec<Vec<usize>>,
34 pub stage_costs: Vec<f64>,
36 pub total_cost: f64,
38 pub peak_memory: usize,
40 pub objective: SchedulingObjective,
42}
43
44impl ExecutionSchedule {
45 pub fn new(objective: SchedulingObjective) -> Self {
47 Self {
48 execution_order: Vec::new(),
49 parallel_stages: Vec::new(),
50 stage_costs: Vec::new(),
51 total_cost: 0.0,
52 peak_memory: 0,
53 objective,
54 }
55 }
56
57 pub fn num_stages(&self) -> usize {
59 self.parallel_stages.len()
60 }
61
62 pub fn max_parallelism(&self) -> usize {
64 self.parallel_stages
65 .iter()
66 .map(|s| s.len())
67 .max()
68 .unwrap_or(0)
69 }
70
71 pub fn avg_parallelism(&self) -> f64 {
73 if self.parallel_stages.is_empty() {
74 return 0.0;
75 }
76 let total: usize = self.parallel_stages.iter().map(|s| s.len()).sum();
77 total as f64 / self.parallel_stages.len() as f64
78 }
79}
80
/// Builds [`ExecutionSchedule`]s for an [`EinsumGraph`] from optional
/// per-node cost and per-tensor memory hints.
pub struct GraphScheduler {
    /// Execution cost per node index; nodes without an entry are treated as
    /// costing `1.0`.
    operation_costs: HashMap<usize, f64>,
    /// Memory footprint per tensor index; tensors without an entry are
    /// treated as occupying `1` unit.
    tensor_memory: HashMap<usize, usize>,
}
88
89impl GraphScheduler {
90 pub fn new() -> Self {
92 Self {
93 operation_costs: HashMap::new(),
94 tensor_memory: HashMap::new(),
95 }
96 }
97
98 pub fn set_operation_cost(&mut self, node_idx: usize, cost: f64) {
100 self.operation_costs.insert(node_idx, cost);
101 }
102
103 pub fn set_tensor_memory(&mut self, tensor_idx: usize, size: usize) {
105 self.tensor_memory.insert(tensor_idx, size);
106 }
107
108 pub fn schedule(
110 &self,
111 graph: &EinsumGraph,
112 objective: SchedulingObjective,
113 ) -> Result<ExecutionSchedule, IrError> {
114 match objective {
115 SchedulingObjective::MinimizeLatency => self.schedule_min_latency(graph),
116 SchedulingObjective::MaximizeThroughput => self.schedule_max_throughput(graph),
117 SchedulingObjective::MinimizeMemory => self.schedule_min_memory(graph),
118 SchedulingObjective::Balanced => self.schedule_balanced(graph),
119 SchedulingObjective::Pipeline => self.schedule_pipeline(graph),
120 }
121 }
122
123 fn schedule_min_latency(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
125 let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
126
127 let dependencies = self.build_dependencies(graph);
129
130 let start_times = self.compute_start_times(graph, &dependencies);
132
133 let mut stages: HashMap<usize, Vec<usize>> = HashMap::new();
135 for (node_idx, &start_time) in start_times.iter().enumerate() {
136 stages
137 .entry(start_time as usize)
138 .or_default()
139 .push(node_idx);
140 }
141
142 let mut stage_indices: Vec<_> = stages.keys().copied().collect();
144 stage_indices.sort_unstable();
145
146 for stage_idx in stage_indices {
147 if let Some(nodes) = stages.get(&stage_idx) {
148 let stage_cost = nodes
149 .iter()
150 .map(|&idx| self.get_operation_cost(idx))
151 .max_by(|a, b| a.partial_cmp(b).unwrap())
152 .unwrap_or(0.0);
153
154 schedule.parallel_stages.push(nodes.clone());
155 schedule.stage_costs.push(stage_cost);
156 schedule.total_cost += stage_cost;
157
158 for &node in nodes {
159 schedule.execution_order.push(node);
160 }
161 }
162 }
163
164 Ok(schedule)
165 }
166
167 fn schedule_max_throughput(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
169 let mut schedule = ExecutionSchedule::new(SchedulingObjective::MaximizeThroughput);
170
171 let dependencies = self.build_dependencies(graph);
173 #[allow(clippy::unnecessary_map_or)]
174 let mut ready: Vec<usize> = (0..graph.nodes.len())
175 .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
176 .collect();
177
178 ready.sort_by(|&a, &b| {
180 let cost_a = self.get_operation_cost(a);
181 let cost_b = self.get_operation_cost(b);
182 cost_b.partial_cmp(&cost_a).unwrap()
183 });
184
185 let mut scheduled = HashSet::new();
186 let _in_degree = self.compute_in_degrees(graph, &dependencies);
187
188 while !ready.is_empty() {
189 let mut stage = Vec::new();
190 let mut stage_cost: f64 = 0.0;
191
192 for &node_idx in &ready {
194 let cost = self.get_operation_cost(node_idx);
195 stage.push(node_idx);
196 stage_cost = stage_cost.max(cost);
197 scheduled.insert(node_idx);
198 schedule.execution_order.push(node_idx);
199 }
200
201 schedule.parallel_stages.push(stage);
202 schedule.stage_costs.push(stage_cost);
203 schedule.total_cost += stage_cost;
204
205 ready.clear();
207 for (node_idx, deps) in &dependencies {
208 if scheduled.contains(node_idx) {
209 continue;
210 }
211
212 let all_deps_scheduled = deps.iter().all(|&dep| scheduled.contains(&dep));
213 if all_deps_scheduled {
214 ready.push(*node_idx);
215 }
216 }
217
218 ready.sort_by(|&a, &b| {
220 let cost_a = self.get_operation_cost(a);
221 let cost_b = self.get_operation_cost(b);
222 cost_b.partial_cmp(&cost_a).unwrap()
223 });
224 }
225
226 Ok(schedule)
227 }
228
229 fn schedule_min_memory(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
231 let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeMemory);
232
233 let dependencies = self.build_dependencies(graph);
235 let tensor_lifetimes = self.compute_tensor_lifetimes(graph);
236
237 #[allow(clippy::unnecessary_map_or)]
238 let mut ready: Vec<usize> = (0..graph.nodes.len())
239 .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
240 .collect();
241
242 let mut scheduled = HashSet::new();
243
244 while !ready.is_empty() {
245 let best_idx = ready
247 .iter()
248 .max_by_key(|&&idx| self.estimate_memory_freed(graph, idx, &tensor_lifetimes))
249 .copied()
250 .unwrap();
251
252 ready.retain(|&idx| idx != best_idx);
253
254 schedule.execution_order.push(best_idx);
255 schedule.parallel_stages.push(vec![best_idx]);
256 let cost = self.get_operation_cost(best_idx);
257 schedule.stage_costs.push(cost);
258 schedule.total_cost += cost;
259 scheduled.insert(best_idx);
260
261 for (node_idx, deps) in &dependencies {
263 if scheduled.contains(node_idx) || ready.contains(node_idx) {
264 continue;
265 }
266
267 if deps.iter().all(|&dep| scheduled.contains(&dep)) {
268 ready.push(*node_idx);
269 }
270 }
271 }
272
273 Ok(schedule)
274 }
275
276 fn schedule_balanced(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
278 let latency_schedule = self.schedule_min_latency(graph)?;
280 let _memory_schedule = self.schedule_min_memory(graph)?;
281
282 Ok(latency_schedule)
285 }
286
287 fn schedule_pipeline(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
289 let mut schedule = ExecutionSchedule::new(SchedulingObjective::Pipeline);
290
291 let stages = self.partition_for_pipeline(graph)?;
293
294 for stage_nodes in stages {
295 let stage_cost = stage_nodes
296 .iter()
297 .map(|&idx| self.get_operation_cost(idx))
298 .sum();
299
300 schedule.parallel_stages.push(stage_nodes.clone());
301 schedule.stage_costs.push(stage_cost);
302 schedule.total_cost = schedule.total_cost.max(stage_cost);
303
304 for &node in &stage_nodes {
305 schedule.execution_order.push(node);
306 }
307 }
308
309 Ok(schedule)
310 }
311
312 fn build_dependencies(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
314 let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
315 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
316
317 for (node_idx, node) in graph.nodes.iter().enumerate() {
319 for &output_idx in &node.outputs {
320 tensor_producer.insert(output_idx, node_idx);
321 }
322 }
323
324 for (node_idx, node) in graph.nodes.iter().enumerate() {
326 let mut deps = Vec::new();
327 for &input_idx in &node.inputs {
328 if let Some(&producer) = tensor_producer.get(&input_idx) {
329 if producer != node_idx {
330 deps.push(producer);
331 }
332 }
333 }
334 dependencies.insert(node_idx, deps);
335 }
336
337 dependencies
338 }
339
340 fn compute_start_times(
342 &self,
343 graph: &EinsumGraph,
344 dependencies: &HashMap<usize, Vec<usize>>,
345 ) -> Vec<f64> {
346 let mut start_times = vec![0.0; graph.nodes.len()];
347 let mut visited = HashSet::new();
348 let mut queue = VecDeque::new();
349
350 for (node_idx, deps) in dependencies {
352 if deps.is_empty() {
353 queue.push_back(*node_idx);
354 }
355 }
356
357 while let Some(node_idx) = queue.pop_front() {
358 if visited.contains(&node_idx) {
359 continue;
360 }
361 visited.insert(node_idx);
362
363 let deps = dependencies
365 .get(&node_idx)
366 .map(|v| v.as_slice())
367 .unwrap_or(&[]);
368 let max_dep_finish = deps
369 .iter()
370 .map(|&dep_idx| start_times[dep_idx] + self.get_operation_cost(dep_idx))
371 .max_by(|a, b| a.partial_cmp(b).unwrap())
372 .unwrap_or(0.0);
373
374 start_times[node_idx] = max_dep_finish;
375
376 for (succ_idx, succ_deps) in dependencies {
378 if succ_deps.contains(&node_idx) && !visited.contains(succ_idx) {
379 queue.push_back(*succ_idx);
380 }
381 }
382 }
383
384 start_times
385 }
386
387 fn compute_in_degrees(
389 &self,
390 graph: &EinsumGraph,
391 dependencies: &HashMap<usize, Vec<usize>>,
392 ) -> Vec<usize> {
393 let mut in_degree = vec![0; graph.nodes.len()];
394 for (node_idx, deps) in dependencies {
395 in_degree[*node_idx] = deps.len();
396 }
397 in_degree
398 }
399
400 fn compute_tensor_lifetimes(&self, graph: &EinsumGraph) -> HashMap<usize, (usize, usize)> {
402 let mut lifetimes = HashMap::new();
403
404 for (node_idx, node) in graph.nodes.iter().enumerate() {
405 for &tensor_idx in &node.inputs {
406 let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
407 entry.0 = entry.0.min(node_idx);
408 entry.1 = entry.1.max(node_idx);
409 }
410 for &tensor_idx in &node.outputs {
411 let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
412 entry.0 = entry.0.min(node_idx);
413 entry.1 = entry.1.max(node_idx);
414 }
415 }
416
417 lifetimes
418 }
419
420 fn estimate_memory_freed(
422 &self,
423 graph: &EinsumGraph,
424 node_idx: usize,
425 lifetimes: &HashMap<usize, (usize, usize)>,
426 ) -> usize {
427 let node = &graph.nodes[node_idx];
428 let mut freed = 0;
429
430 for &input_tensor in &node.inputs {
431 if let Some(&(_, last_use)) = lifetimes.get(&input_tensor) {
432 if last_use == node_idx {
433 freed += self.tensor_memory.get(&input_tensor).copied().unwrap_or(1);
434 }
435 }
436 }
437
438 freed
439 }
440
441 fn partition_for_pipeline(&self, graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
443 let total_cost: f64 = (0..graph.nodes.len())
445 .map(|i| self.get_operation_cost(i))
446 .sum();
447
448 let target_stages = 4; let target_cost_per_stage = total_cost / target_stages as f64;
450
451 let dependencies = self.build_dependencies(graph);
452 let topo_order = self.topological_sort(graph, &dependencies);
453
454 let mut stages = Vec::new();
455 let mut current_stage = Vec::new();
456 let mut current_cost = 0.0;
457
458 for &node_idx in &topo_order {
459 let cost = self.get_operation_cost(node_idx);
460 current_stage.push(node_idx);
461 current_cost += cost;
462
463 if current_cost >= target_cost_per_stage {
464 stages.push(current_stage.clone());
465 current_stage.clear();
466 current_cost = 0.0;
467 }
468 }
469
470 if !current_stage.is_empty() {
471 stages.push(current_stage);
472 }
473
474 Ok(stages)
475 }
476
477 fn topological_sort(
479 &self,
480 graph: &EinsumGraph,
481 dependencies: &HashMap<usize, Vec<usize>>,
482 ) -> Vec<usize> {
483 let mut result = Vec::new();
484 let mut visited = HashSet::new();
485 let mut in_degree = self.compute_in_degrees(graph, dependencies);
486
487 let mut queue: VecDeque<usize> = (0..graph.nodes.len())
488 .filter(|&i| in_degree[i] == 0)
489 .collect();
490
491 while let Some(node_idx) = queue.pop_front() {
492 if visited.contains(&node_idx) {
493 continue;
494 }
495 visited.insert(node_idx);
496 result.push(node_idx);
497
498 for (succ_idx, deps) in dependencies {
500 if deps.contains(&node_idx) {
501 in_degree[*succ_idx] = in_degree[*succ_idx].saturating_sub(1);
502 if in_degree[*succ_idx] == 0 {
503 queue.push_back(*succ_idx);
504 }
505 }
506 }
507 }
508
509 result
510 }
511
512 fn get_operation_cost(&self, node_idx: usize) -> f64 {
514 self.operation_costs.get(&node_idx).copied().unwrap_or(1.0)
515 }
516}
517
518impl Default for GraphScheduler {
519 fn default() -> Self {
520 Self::new()
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    /// Two-node chain `A -relu-> B -tanh-> C`: node 0 produces `B`, which
    /// node 1 consumes.
    fn relu_tanh_chain() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph.add_node(EinsumNode::elem_unary("relu", a, b)).unwrap();
        graph.add_node(EinsumNode::elem_unary("tanh", b, c)).unwrap();
        graph
    }

    #[test]
    fn test_execution_schedule_creation() {
        let schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.objective, SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.num_stages(), 0);
    }

    #[test]
    fn test_execution_schedule_stats() {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        schedule.parallel_stages.extend([vec![0, 1, 2], vec![3]]);

        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.max_parallelism(), 3);
        assert_eq!(schedule.avg_parallelism(), 2.0);
    }

    #[test]
    fn test_scheduler_creation() {
        assert!(GraphScheduler::new().operation_costs.is_empty());
    }

    #[test]
    fn test_scheduler_set_costs() {
        let mut scheduler = GraphScheduler::new();
        scheduler.set_operation_cost(0, 5.0);
        scheduler.set_tensor_memory(1, 1024);

        assert_eq!(scheduler.get_operation_cost(0), 5.0);
        assert_eq!(scheduler.tensor_memory.get(&1), Some(&1024));
    }

    #[test]
    fn test_schedule_empty_graph() {
        let graph = EinsumGraph::new();
        let schedule = GraphScheduler::new()
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .unwrap();
        assert_eq!(schedule.num_stages(), 0);
    }

    #[test]
    fn test_schedule_single_node() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph.add_node(EinsumNode::elem_unary("relu", a, b)).unwrap();

        let mut scheduler = GraphScheduler::new();
        scheduler.set_operation_cost(0, 2.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .unwrap();
        assert_eq!(schedule.execution_order.len(), 1);
        assert_eq!(schedule.total_cost, 2.0);
    }

    #[test]
    fn test_build_dependencies() {
        let graph = relu_tanh_chain();
        let deps = GraphScheduler::new().build_dependencies(&graph);

        assert!(deps[&0].is_empty());
        assert_eq!(deps[&1], vec![0]);
    }

    #[test]
    fn test_topological_sort() {
        let graph = relu_tanh_chain();
        let scheduler = GraphScheduler::new();
        let deps = scheduler.build_dependencies(&graph);

        assert_eq!(scheduler.topological_sort(&graph, &deps), vec![0, 1]);
    }

    #[test]
    fn test_scheduling_objectives() {
        let latency = SchedulingObjective::MinimizeLatency;
        assert_eq!(latency, SchedulingObjective::MinimizeLatency);
        assert_ne!(latency, SchedulingObjective::MaximizeThroughput);
    }
}