use std::collections::{HashMap, HashSet, VecDeque};

use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};

use crate::capabilities::DeviceType;

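/// A complete plan for executing an `EinsumGraph`: which nodes run, in what
/// order, on which device, which nodes may run concurrently, and a rough cost
/// estimate for the whole plan.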
#[derive(Debug, Clone)]
pub struct ExecutionSchedule {
    /// Node indices in the order they should be executed.
    pub execution_order: Vec<usize>,
    /// Device assigned to each node index.
    pub device_placement: HashMap<usize, DeviceType>,
    /// Groups of node indices that can execute concurrently.
    pub parallel_groups: Vec<Vec<usize>>,
    /// Rough estimate of the total execution cost.
    pub estimated_cost: f64,
}

impl ExecutionSchedule {
    pub fn new() -> Self {
        ExecutionSchedule {
            execution_order: Vec::new(),
            device_placement: HashMap::new(),
            parallel_groups: Vec::new(),
            estimated_cost: 0.0,
        }
    }

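    /// Builds a schedule that runs all `num_nodes` nodes one at a time, in
    /// index order, on the given device.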
    pub fn sequential(num_nodes: usize, device: DeviceType) -> Self {
        let execution_order: Vec<usize> = (0..num_nodes).collect();
        let device_placement: HashMap<_, _> = (0..num_nodes).map(|i| (i, device)).collect();
        let parallel_groups: Vec<Vec<usize>> = execution_order.iter().map(|&i| vec![i]).collect();

        ExecutionSchedule {
            execution_order,
            device_placement,
            parallel_groups,
            estimated_cost: num_nodes as f64,
        }
    }

    pub fn len(&self) -> usize {
        self.execution_order.len()
    }

    pub fn is_empty(&self) -> bool {
        self.execution_order.is_empty()
    }

    pub fn get_device(&self, node_idx: usize) -> Option<DeviceType> {
        self.device_placement.get(&node_idx).copied()
    }

    pub fn num_parallel_stages(&self) -> usize {
        self.parallel_groups.len()
    }
}

impl Default for ExecutionSchedule {
    fn default() -> Self {
        Self::new()
    }
}

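/// Strategy used by [`Scheduler`] to order nodes and group them into
/// parallel stages.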
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingStrategy {
    /// Execute nodes one at a time in index order.
    Sequential,
    /// Group as many independent nodes as possible into each parallel stage.
    MaximizeParallelism,
    /// Execute nodes one at a time in dependency order.
    MinimizeMemory,
    /// Start from the parallel schedule, then merge small stages into batches.
    Balanced,
    /// Prioritize nodes by their estimated critical-path cost.
    CostBased,
}

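/// Per-node cost estimate used by the cost-based scheduling strategy.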
#[derive(Debug, Clone)]
pub struct NodeCost {
    pub compute_cost: f64,
    pub memory_cost: usize,
    pub communication_cost: f64,
}

impl NodeCost {
    pub fn new() -> Self {
        NodeCost {
            compute_cost: 1.0,
            memory_cost: 0,
            communication_cost: 0.0,
        }
    }

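    /// Estimates the cost of a single node from its operation type. The einsum
    /// cost grows quadratically with the number of index characters in the
    /// spec; element-wise and reduce ops are assigned small fixed costs.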
    pub fn estimate_from_node(node: &EinsumNode) -> Self {
        let compute_cost = match &node.op {
            OpType::Einsum { spec } => {
                let num_indices = spec.chars().filter(|c| c.is_alphabetic()).count();
                (num_indices as f64).powi(2)
            }
            OpType::ElemUnary { .. } => 1.0,
            OpType::ElemBinary { .. } => 1.5,
            OpType::Reduce { axes, .. } => 2.0 + axes.len() as f64,
        };

        NodeCost {
            compute_cost,
            // Fixed placeholder; actual tensor sizes are not known here.
            memory_cost: 1024,
            communication_cost: 0.0,
        }
    }

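    /// Combines compute cost, communication cost, and memory cost (scaled down
    /// by 1024) into a single scalar.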
    pub fn total_cost(&self) -> f64 {
        self.compute_cost + self.communication_cost + (self.memory_cost as f64 / 1024.0)
    }
}

impl Default for NodeCost {
    fn default() -> Self {
        Self::new()
    }
}

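/// Plans the execution of an `EinsumGraph` according to a [`SchedulingStrategy`].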
pub struct Scheduler {
    strategy: SchedulingStrategy,
}

impl Scheduler {
    pub fn new(strategy: SchedulingStrategy) -> Self {
        Scheduler { strategy }
    }

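    /// Produces an execution schedule for `graph` using the configured strategy.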
    pub fn schedule(&self, graph: &EinsumGraph) -> ExecutionSchedule {
        match self.strategy {
            SchedulingStrategy::Sequential => self.schedule_sequential(graph),
            SchedulingStrategy::MaximizeParallelism => self.schedule_parallel(graph),
            SchedulingStrategy::MinimizeMemory => self.schedule_memory_efficient(graph),
            SchedulingStrategy::Balanced => self.schedule_balanced(graph),
            SchedulingStrategy::CostBased => self.schedule_cost_based(graph),
        }
    }

    fn schedule_sequential(&self, graph: &EinsumGraph) -> ExecutionSchedule {
        ExecutionSchedule::sequential(graph.nodes.len(), DeviceType::CPU)
    }

    fn schedule_parallel(&self, graph: &EinsumGraph) -> ExecutionSchedule {
        let mut schedule = ExecutionSchedule::new();
        let num_nodes = graph.nodes.len();
        let _num_tensors = graph.tensors.len();

        let deps = self.build_dependency_graph(graph);

        // Assign each node to a level; nodes on the same level have no
        // dependencies on each other and can run in parallel.
        let levels = self.compute_node_levels(graph, &deps);

        let max_level = *levels.values().max().unwrap_or(&0);
        let mut level_groups: Vec<Vec<usize>> = vec![Vec::new(); max_level + 1];

        for (node_idx, &level) in &levels {
            level_groups[level].push(*node_idx);
        }

        // Emit one parallel stage per level, in level order.
        for group in &level_groups {
            schedule.execution_order.extend(group);
            if !group.is_empty() {
                schedule.parallel_groups.push(group.clone());
            }
        }

        for i in 0..num_nodes {
            schedule.device_placement.insert(i, DeviceType::CPU);
        }

        // Cost is the number of parallel stages, assuming unit cost per stage.
        schedule.estimated_cost = (max_level + 1) as f64;

        schedule
    }

    fn schedule_memory_efficient(&self, graph: &EinsumGraph) -> ExecutionSchedule {
        let mut schedule = ExecutionSchedule::new();
        let num_nodes = graph.nodes.len();
        let num_tensors = graph.tensors.len();

        let deps = self.build_dependency_graph(graph);

        // Greedy topological order: pop ready nodes one at a time, one node per stage.
        let mut executed = HashSet::new();
        let mut ready_queue = VecDeque::new();

        for node_idx in 0..num_nodes {
            if self.is_ready(node_idx, &deps, &executed, num_tensors) {
                ready_queue.push_back(node_idx);
            }
        }

        while let Some(node_idx) = ready_queue.pop_front() {
            if executed.contains(&node_idx) {
                continue;
            }

            schedule.execution_order.push(node_idx);
            schedule.parallel_groups.push(vec![node_idx]);
            schedule.device_placement.insert(node_idx, DeviceType::CPU);
            executed.insert(node_idx);

            // Enqueue any nodes that became ready after executing this one.
            for next_idx in 0..num_nodes {
                if !executed.contains(&next_idx)
                    && self.is_ready(next_idx, &deps, &executed, num_tensors)
                {
                    ready_queue.push_back(next_idx);
                }
            }
        }

        schedule.estimated_cost = num_nodes as f64;
        schedule
    }

    fn schedule_balanced(&self, graph: &EinsumGraph) -> ExecutionSchedule {
        // Start from the maximally parallel schedule, then merge small stages.
        let mut parallel_schedule = self.schedule_parallel(graph);

        let mut merged_groups = Vec::new();
        let mut current_group = Vec::new();

        for group in parallel_schedule.parallel_groups {
            if group.len() > 4 {
                // Large stages are kept as-is; flush any pending merged stage first.
                if !current_group.is_empty() {
                    merged_groups.push(current_group.clone());
                    current_group.clear();
                }
                merged_groups.push(group);
            } else {
                // Accumulate small stages until the merged stage reaches size 4.
                current_group.extend(group);
                if current_group.len() >= 4 {
                    merged_groups.push(current_group.clone());
                    current_group.clear();
                }
            }
        }

        if !current_group.is_empty() {
            merged_groups.push(current_group);
        }

        parallel_schedule.parallel_groups = merged_groups;
        // Apply a small overhead factor to the parallel cost estimate.
        parallel_schedule.estimated_cost *= 1.2;
        parallel_schedule
    }

    fn schedule_cost_based(&self, graph: &EinsumGraph) -> ExecutionSchedule {
        let mut schedule = ExecutionSchedule::new();
        let num_nodes = graph.nodes.len();

        // Estimate a cost for every node.
        let costs: Vec<NodeCost> = graph
            .nodes
            .iter()
            .map(NodeCost::estimate_from_node)
            .collect();

        let deps = self.build_dependency_graph(graph);

        // Prioritize nodes by critical-path cost: the node's own cost plus the
        // cost of the most expensive chain of nodes that depends on it.
        let critical_costs = self.compute_critical_path_costs(graph, &costs, &deps);

        let mut node_priorities: Vec<(usize, f64)> = critical_costs
            .iter()
            .enumerate()
            .map(|(i, &cost)| (i, cost))
            .collect();
        node_priorities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut executed = HashSet::new();
        let num_tensors = graph.tensors.len();

        // Schedule in waves: each wave takes every ready node, highest priority first.
        while executed.len() < num_nodes {
            let mut current_wave = Vec::new();

            for &(node_idx, _) in &node_priorities {
                if executed.contains(&node_idx) {
                    continue;
                }

                if self.is_ready(node_idx, &deps, &executed, num_tensors) {
                    current_wave.push(node_idx);
                    executed.insert(node_idx);
                }
            }

            if current_wave.is_empty() {
                // No progress is possible (e.g. a dependency cycle); stop early.
                break;
            }

            schedule.execution_order.extend(&current_wave);
            schedule.parallel_groups.push(current_wave);
        }

        for i in 0..num_nodes {
            schedule.device_placement.insert(i, DeviceType::CPU);
        }

        schedule.estimated_cost = costs.iter().map(|c| c.total_cost()).sum();

        schedule
    }

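    /// Maps each node index to the indices of the nodes that produce its inputs.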
    fn build_dependency_graph(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
        let mut deps: HashMap<usize, Vec<usize>> = HashMap::new();

        // Record which node produces each tensor.
        let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            for &output_idx in &node.outputs {
                tensor_producers.insert(output_idx, node_idx);
            }
        }

        // A node depends on the producer of each of its input tensors.
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            let mut node_deps = Vec::new();
            for &input_idx in &node.inputs {
                if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
                    node_deps.push(producer_idx);
                }
            }
            deps.insert(node_idx, node_deps);
        }

        deps
    }

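    /// Assigns each node a level: nodes with no dependencies get level 0, and
    /// every other node sits one level above its deepest dependency.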
    fn compute_node_levels(
        &self,
        graph: &EinsumGraph,
        deps: &HashMap<usize, Vec<usize>>,
    ) -> HashMap<usize, usize> {
        let mut levels = HashMap::new();
        let num_nodes = graph.nodes.len();

        // Iterate num_nodes times so the levels converge regardless of node order.
        for _ in 0..num_nodes {
            for node_idx in 0..num_nodes {
                let max_dep_level = deps
                    .get(&node_idx)
                    .map(|d| d.iter().filter_map(|&i| levels.get(&i)).max().copied())
                    .unwrap_or(None);

                let level = max_dep_level.map(|l| l + 1).unwrap_or(0);
                levels.insert(node_idx, level);
            }
        }

        levels
    }

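    /// For each node, computes its own cost plus the cost of the most
    /// expensive downstream chain, i.e. its critical-path cost.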
    fn compute_critical_path_costs(
        &self,
        graph: &EinsumGraph,
        costs: &[NodeCost],
        deps: &HashMap<usize, Vec<usize>>,
    ) -> Vec<f64> {
        let num_nodes = graph.nodes.len();
        let mut critical_costs = vec![0.0; num_nodes];

        // Iterate to a fixed point, propagating costs backwards from consumers.
        for _ in 0..num_nodes {
            for node_idx in (0..num_nodes).rev() {
                let node_cost = costs[node_idx].total_cost();

                // Highest critical-path cost among the nodes that consume this node's output.
                let max_successor_cost = (0..num_nodes)
                    .filter(|&i| deps.get(&i).map(|d| d.contains(&node_idx)).unwrap_or(false))
                    .map(|i| critical_costs[i])
                    .max_by(|a, b| a.partial_cmp(b).unwrap())
                    .unwrap_or(0.0);

                critical_costs[node_idx] = node_cost + max_successor_cost;
            }
        }

        critical_costs
    }

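    /// A node is ready once all of the nodes it depends on have been executed.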
    fn is_ready(
        &self,
        node_idx: usize,
        deps: &HashMap<usize, Vec<usize>>,
        executed: &HashSet<usize>,
        _num_tensors: usize,
    ) -> bool {
        deps.get(&node_idx)
            .map(|d| d.iter().all(|&dep| executed.contains(&dep)))
            .unwrap_or(true)
    }
}

impl Default for Scheduler {
    fn default() -> Self {
        Self::new(SchedulingStrategy::Balanced)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.tensors.push("x".to_string());
        graph.tensors.push("y".to_string());
        graph.tensors.push("t2".to_string());
        graph.tensors.push("t3".to_string());
        graph.tensors.push("t4".to_string());

        // Node 0: t2 = einsum("ab,bc->ac", x, y)
        graph.nodes.push(EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        });

        // Node 1: t3 = relu(t2)
        graph.nodes.push(EinsumNode {
            inputs: vec![2],
            outputs: vec![3],
            op: OpType::ElemUnary { op: "relu".into() },
            metadata: None,
        });

        // Node 2: t4 = sum(t3) over axis 0
        graph.nodes.push(EinsumNode {
            inputs: vec![3],
            outputs: vec![4],
            op: OpType::Reduce {
                op: "sum".into(),
                axes: vec![0],
            },
            metadata: None,
        });

        graph
    }

    #[test]
    fn test_execution_schedule_creation() {
        let schedule = ExecutionSchedule::new();
        assert!(schedule.is_empty());
        assert_eq!(schedule.num_parallel_stages(), 0);
    }

    #[test]
    fn test_sequential_schedule() {
        let schedule = ExecutionSchedule::sequential(5, DeviceType::CPU);
        assert_eq!(schedule.len(), 5);
        assert_eq!(schedule.execution_order, vec![0, 1, 2, 3, 4]);
        assert_eq!(schedule.num_parallel_stages(), 5);

        for i in 0..5 {
            assert_eq!(schedule.get_device(i), Some(DeviceType::CPU));
        }
    }

    #[test]
    fn test_node_cost_estimation() {
        let node = EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        };

        let cost = NodeCost::estimate_from_node(&node);
        assert!(cost.compute_cost > 0.0);
        assert!(cost.total_cost() > 0.0);
    }

    #[test]
    fn test_scheduler_sequential() {
        let graph = create_test_graph();
        let scheduler = Scheduler::new(SchedulingStrategy::Sequential);
        let schedule = scheduler.schedule(&graph);

        assert_eq!(schedule.len(), 3);
        assert_eq!(schedule.execution_order, vec![0, 1, 2]);
    }

    #[test]
    fn test_scheduler_parallel() {
        let graph = create_test_graph();
        let scheduler = Scheduler::new(SchedulingStrategy::MaximizeParallelism);
        let schedule = scheduler.schedule(&graph);

        assert_eq!(schedule.len(), 3);
        assert!(schedule.num_parallel_stages() <= 3);
    }

    #[test]
    fn test_scheduler_memory_efficient() {
        let graph = create_test_graph();
        let scheduler = Scheduler::new(SchedulingStrategy::MinimizeMemory);
        let schedule = scheduler.schedule(&graph);

        assert_eq!(schedule.len(), 3);
        assert!(schedule.execution_order.contains(&0));
        assert!(schedule.execution_order.contains(&1));
        assert!(schedule.execution_order.contains(&2));
    }

    #[test]
    fn test_scheduler_balanced() {
        let graph = create_test_graph();
        let scheduler = Scheduler::new(SchedulingStrategy::Balanced);
        let schedule = scheduler.schedule(&graph);

        assert_eq!(schedule.len(), 3);
        assert!(schedule.estimated_cost > 0.0);
    }

    #[test]
    fn test_scheduler_cost_based() {
        let graph = create_test_graph();
        let scheduler = Scheduler::new(SchedulingStrategy::CostBased);
        let schedule = scheduler.schedule(&graph);

        assert_eq!(schedule.len(), 3);
        assert!(schedule.estimated_cost > 0.0);
    }

    #[test]
    fn test_dependency_graph_building() {
        let graph = create_test_graph();
        let scheduler = Scheduler::default();
        let deps = scheduler.build_dependency_graph(&graph);

        assert_eq!(deps.len(), 3);
        assert!(deps[&0].is_empty()); // node 0 consumes only graph inputs
        assert_eq!(deps[&1], vec![0]); // node 1 consumes node 0's output
        assert_eq!(deps[&2], vec![1]); // node 2 consumes node 1's output
    }

    #[test]
    fn test_node_levels() {
        let graph = create_test_graph();
        let scheduler = Scheduler::default();
        let deps = scheduler.build_dependency_graph(&graph);
        let levels = scheduler.compute_node_levels(&graph, &deps);

        assert_eq!(levels[&0], 0);
        assert_eq!(levels[&1], 1);
        assert_eq!(levels[&2], 2);
    }

    #[test]
    fn test_scheduling_strategies() {
        let strategies = vec![
            SchedulingStrategy::Sequential,
            SchedulingStrategy::MaximizeParallelism,
            SchedulingStrategy::MinimizeMemory,
            SchedulingStrategy::Balanced,
            SchedulingStrategy::CostBased,
        ];

        let graph = create_test_graph();

        for strategy in strategies {
            let scheduler = Scheduler::new(strategy);
            let schedule = scheduler.schedule(&graph);
            assert_eq!(schedule.len(), 3, "Strategy {:?} failed", strategy);
        }
    }
}