Skip to main content

tensorlogic_infer/
scheduling.rs

1//! Execution scheduling and optimization for efficient graph execution.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
6
7use crate::capabilities::DeviceType;
8
9/// Execution schedule for a graph
10#[derive(Debug, Clone)]
11pub struct ExecutionSchedule {
12    /// Ordered list of node indices to execute
13    pub execution_order: Vec<usize>,
14    /// Device placement for each node
15    pub device_placement: HashMap<usize, DeviceType>,
16    /// Parallel execution groups (nodes that can run concurrently)
17    pub parallel_groups: Vec<Vec<usize>>,
18    /// Estimated execution cost (arbitrary units)
19    pub estimated_cost: f64,
20}
21
22impl ExecutionSchedule {
23    pub fn new() -> Self {
24        ExecutionSchedule {
25            execution_order: Vec::new(),
26            device_placement: HashMap::new(),
27            parallel_groups: Vec::new(),
28            estimated_cost: 0.0,
29        }
30    }
31
32    pub fn sequential(num_nodes: usize, device: DeviceType) -> Self {
33        let execution_order: Vec<usize> = (0..num_nodes).collect();
34        let device_placement: HashMap<_, _> = (0..num_nodes).map(|i| (i, device)).collect();
35        let parallel_groups: Vec<Vec<usize>> = execution_order.iter().map(|&i| vec![i]).collect();
36
37        ExecutionSchedule {
38            execution_order,
39            device_placement,
40            parallel_groups,
41            estimated_cost: num_nodes as f64,
42        }
43    }
44
45    pub fn len(&self) -> usize {
46        self.execution_order.len()
47    }
48
49    pub fn is_empty(&self) -> bool {
50        self.execution_order.is_empty()
51    }
52
53    pub fn get_device(&self, node_idx: usize) -> Option<DeviceType> {
54        self.device_placement.get(&node_idx).copied()
55    }
56
57    pub fn num_parallel_stages(&self) -> usize {
58        self.parallel_groups.len()
59    }
60}
61
62impl Default for ExecutionSchedule {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
/// Strategy used by [`Scheduler`] to order and group the nodes of a graph.
///
/// Derives `Hash` (in addition to `Eq`) so strategies can key maps and sets, e.g. to cache one
/// schedule per strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchedulingStrategy {
    /// Execute nodes one at a time in topological order.
    Sequential,
    /// Maximize parallelism (minimize depth) by grouping nodes by dependency level.
    MaximizeParallelism,
    /// Favor low memory use by running a single node per stage.
    MinimizeMemory,
    /// Balance parallelism and memory.
    Balanced,
    /// Order nodes by an estimated cost model (critical-path priority).
    CostBased,
}
82
83/// Node cost model for scheduling decisions
84#[derive(Debug, Clone)]
85pub struct NodeCost {
86    pub compute_cost: f64,
87    pub memory_cost: usize,
88    pub communication_cost: f64,
89}
90
91impl NodeCost {
92    pub fn new() -> Self {
93        NodeCost {
94            compute_cost: 1.0,
95            memory_cost: 0,
96            communication_cost: 0.0,
97        }
98    }
99
100    pub fn estimate_from_node(node: &EinsumNode) -> Self {
101        let compute_cost = match &node.op {
102            OpType::Einsum { spec } => {
103                // Estimate based on einsum complexity
104                let num_indices = spec.chars().filter(|c| c.is_alphabetic()).count();
105                (num_indices as f64).powi(2) // Rough O(n²) estimate
106            }
107            OpType::ElemUnary { .. } => 1.0,
108            OpType::ElemBinary { .. } => 1.5,
109            OpType::Reduce { axes, .. } => 2.0 + axes.len() as f64,
110        };
111
112        NodeCost {
113            compute_cost,
114            memory_cost: 1024, // Default 1KB estimate
115            communication_cost: 0.0,
116        }
117    }
118
119    pub fn total_cost(&self) -> f64 {
120        self.compute_cost + self.communication_cost + (self.memory_cost as f64 / 1024.0)
121    }
122}
123
124impl Default for NodeCost {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130/// Graph scheduler
131pub struct Scheduler {
132    strategy: SchedulingStrategy,
133}
134
135impl Scheduler {
136    pub fn new(strategy: SchedulingStrategy) -> Self {
137        Scheduler { strategy }
138    }
139
140    /// Generate execution schedule for a graph
141    pub fn schedule(&self, graph: &EinsumGraph) -> ExecutionSchedule {
142        match self.strategy {
143            SchedulingStrategy::Sequential => self.schedule_sequential(graph),
144            SchedulingStrategy::MaximizeParallelism => self.schedule_parallel(graph),
145            SchedulingStrategy::MinimizeMemory => self.schedule_memory_efficient(graph),
146            SchedulingStrategy::Balanced => self.schedule_balanced(graph),
147            SchedulingStrategy::CostBased => self.schedule_cost_based(graph),
148        }
149    }
150
151    fn schedule_sequential(&self, graph: &EinsumGraph) -> ExecutionSchedule {
152        ExecutionSchedule::sequential(graph.nodes.len(), DeviceType::CPU)
153    }
154
155    fn schedule_parallel(&self, graph: &EinsumGraph) -> ExecutionSchedule {
156        let mut schedule = ExecutionSchedule::new();
157        let num_nodes = graph.nodes.len();
158        let _num_tensors = graph.tensors.len();
159
160        // Build dependency graph
161        let deps = self.build_dependency_graph(graph);
162
163        // Compute levels (maximum distance from input)
164        let levels = self.compute_node_levels(graph, &deps);
165
166        // Group nodes by level for parallel execution
167        let max_level = *levels.values().max().unwrap_or(&0);
168        let mut level_groups: Vec<Vec<usize>> = vec![Vec::new(); max_level + 1];
169
170        for (node_idx, &level) in &levels {
171            level_groups[level].push(*node_idx);
172        }
173
174        // Create execution order (level-by-level)
175        for group in &level_groups {
176            schedule.execution_order.extend(group);
177            if !group.is_empty() {
178                schedule.parallel_groups.push(group.clone());
179            }
180        }
181
182        // Assign all nodes to CPU (default)
183        for i in 0..num_nodes {
184            schedule.device_placement.insert(i, DeviceType::CPU);
185        }
186
187        // Estimate cost as number of levels (critical path length)
188        schedule.estimated_cost = (max_level + 1) as f64;
189
190        schedule
191    }
192
193    fn schedule_memory_efficient(&self, graph: &EinsumGraph) -> ExecutionSchedule {
194        let mut schedule = ExecutionSchedule::new();
195        let num_nodes = graph.nodes.len();
196        let num_tensors = graph.tensors.len();
197
198        // Build dependency graph
199        let deps = self.build_dependency_graph(graph);
200
201        // Greedy scheduling: execute nodes that free the most memory first
202        let mut executed = HashSet::new();
203        let mut ready_queue = VecDeque::new();
204
205        // Find initial ready nodes (no dependencies or all deps satisfied)
206        for node_idx in 0..num_nodes {
207            if self.is_ready(node_idx, &deps, &executed, num_tensors) {
208                ready_queue.push_back(node_idx);
209            }
210        }
211
212        while let Some(node_idx) = ready_queue.pop_front() {
213            if executed.contains(&node_idx) {
214                continue;
215            }
216
217            schedule.execution_order.push(node_idx);
218            schedule.parallel_groups.push(vec![node_idx]);
219            schedule.device_placement.insert(node_idx, DeviceType::CPU);
220            executed.insert(node_idx);
221
222            // Add newly ready nodes
223            for next_idx in 0..num_nodes {
224                if !executed.contains(&next_idx)
225                    && self.is_ready(next_idx, &deps, &executed, num_tensors)
226                {
227                    ready_queue.push_back(next_idx);
228                }
229            }
230        }
231
232        schedule.estimated_cost = num_nodes as f64;
233        schedule
234    }
235
236    fn schedule_balanced(&self, graph: &EinsumGraph) -> ExecutionSchedule {
237        // Compromise between parallelism and memory
238        // Use parallel scheduling but limit group sizes
239        let mut parallel_schedule = self.schedule_parallel(graph);
240
241        // Merge small groups to reduce overhead
242        let mut merged_groups = Vec::new();
243        let mut current_group = Vec::new();
244
245        for group in parallel_schedule.parallel_groups {
246            if group.len() > 4 {
247                // Large group: keep separate
248                if !current_group.is_empty() {
249                    merged_groups.push(current_group.clone());
250                    current_group.clear();
251                }
252                merged_groups.push(group);
253            } else {
254                // Small group: accumulate
255                current_group.extend(group);
256                if current_group.len() >= 4 {
257                    merged_groups.push(current_group.clone());
258                    current_group.clear();
259                }
260            }
261        }
262
263        if !current_group.is_empty() {
264            merged_groups.push(current_group);
265        }
266
267        parallel_schedule.parallel_groups = merged_groups;
268        parallel_schedule.estimated_cost *= 1.2; // Slight overhead from merging
269
270        parallel_schedule
271    }
272
273    fn schedule_cost_based(&self, graph: &EinsumGraph) -> ExecutionSchedule {
274        let mut schedule = ExecutionSchedule::new();
275        let num_nodes = graph.nodes.len();
276
277        // Estimate costs for each node
278        let costs: Vec<NodeCost> = graph
279            .nodes
280            .iter()
281            .map(NodeCost::estimate_from_node)
282            .collect();
283
284        // Build dependency graph
285        let deps = self.build_dependency_graph(graph);
286
287        // Compute critical path costs
288        let critical_costs = self.compute_critical_path_costs(graph, &costs, &deps);
289
290        // Sort by critical path cost (highest first for better parallelism)
291        let mut node_priorities: Vec<(usize, f64)> = critical_costs
292            .iter()
293            .enumerate()
294            .map(|(i, &cost)| (i, cost))
295            .collect();
296        node_priorities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
297
298        // Schedule using priority-based topological sort
299        let mut executed = HashSet::new();
300        let num_tensors = graph.tensors.len();
301
302        while executed.len() < num_nodes {
303            let mut current_wave = Vec::new();
304
305            for &(node_idx, _) in &node_priorities {
306                if executed.contains(&node_idx) {
307                    continue;
308                }
309
310                if self.is_ready(node_idx, &deps, &executed, num_tensors) {
311                    current_wave.push(node_idx);
312                    executed.insert(node_idx);
313                }
314            }
315
316            if current_wave.is_empty() {
317                break; // Avoid infinite loop on cyclic graphs
318            }
319
320            schedule.execution_order.extend(&current_wave);
321            schedule.parallel_groups.push(current_wave);
322        }
323
324        // Assign devices (all CPU for now)
325        for i in 0..num_nodes {
326            schedule.device_placement.insert(i, DeviceType::CPU);
327        }
328
329        // Estimate total cost
330        schedule.estimated_cost = costs.iter().map(|c| c.total_cost()).sum();
331
332        schedule
333    }
334
335    fn build_dependency_graph(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
336        let mut deps: HashMap<usize, Vec<usize>> = HashMap::new();
337
338        // Build a mapping from tensor index to the node that produces it
339        let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
340        for (node_idx, node) in graph.nodes.iter().enumerate() {
341            for &output_idx in &node.outputs {
342                tensor_producers.insert(output_idx, node_idx);
343            }
344        }
345
346        // For each node, find which other nodes it depends on
347        for (node_idx, node) in graph.nodes.iter().enumerate() {
348            let mut node_deps = Vec::new();
349            for &input_idx in &node.inputs {
350                // Check if this tensor is produced by another node
351                if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
352                    node_deps.push(producer_idx);
353                }
354            }
355            deps.insert(node_idx, node_deps);
356        }
357
358        deps
359    }
360
361    fn compute_node_levels(
362        &self,
363        graph: &EinsumGraph,
364        deps: &HashMap<usize, Vec<usize>>,
365    ) -> HashMap<usize, usize> {
366        let mut levels = HashMap::new();
367        let num_nodes = graph.nodes.len();
368
369        // Compute levels iteratively
370        for _ in 0..num_nodes {
371            for node_idx in 0..num_nodes {
372                let max_dep_level = deps
373                    .get(&node_idx)
374                    .map(|d| d.iter().filter_map(|&i| levels.get(&i)).max().copied())
375                    .unwrap_or(None);
376
377                let level = max_dep_level.map(|l| l + 1).unwrap_or(0);
378                levels.insert(node_idx, level);
379            }
380        }
381
382        levels
383    }
384
385    fn compute_critical_path_costs(
386        &self,
387        graph: &EinsumGraph,
388        costs: &[NodeCost],
389        deps: &HashMap<usize, Vec<usize>>,
390    ) -> Vec<f64> {
391        let num_nodes = graph.nodes.len();
392        let mut critical_costs = vec![0.0; num_nodes];
393
394        // Compute critical path costs iteratively (reverse topological order)
395        for _ in 0..num_nodes {
396            for node_idx in (0..num_nodes).rev() {
397                let node_cost = costs[node_idx].total_cost();
398
399                // Find max cost among dependent nodes
400                let max_successor_cost = (0..num_nodes)
401                    .filter(|&i| deps.get(&i).map(|d| d.contains(&node_idx)).unwrap_or(false))
402                    .map(|i| critical_costs[i])
403                    .max_by(|a, b| a.partial_cmp(b).unwrap())
404                    .unwrap_or(0.0);
405
406                critical_costs[node_idx] = node_cost + max_successor_cost;
407            }
408        }
409
410        critical_costs
411    }
412
413    fn is_ready(
414        &self,
415        _node_idx: usize,
416        deps: &HashMap<usize, Vec<usize>>,
417        executed: &HashSet<usize>,
418        _num_tensors: usize,
419    ) -> bool {
420        let node_idx = _node_idx;
421        deps.get(&node_idx)
422            .map(|d| d.iter().all(|&dep| executed.contains(&dep)))
423            .unwrap_or(true)
424    }
425}
426
427impl Default for Scheduler {
428    fn default() -> Self {
429        Self::new(SchedulingStrategy::Balanced)
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-node chain over five tensors:
    /// node 0 `einsum(t0, t1) -> t2`, node 1 `relu(t2) -> t3`, node 2 `sum(t3) -> t4`.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        for &name in &["x", "y", "t2", "t3", "t4"] {
            graph.tensors.push(name.to_string());
        }

        let matmul = EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        };
        let relu = EinsumNode {
            inputs: vec![2],
            outputs: vec![3],
            op: OpType::ElemUnary { op: "relu".into() },
            metadata: None,
        };
        let total = EinsumNode {
            inputs: vec![3],
            outputs: vec![4],
            op: OpType::Reduce {
                op: "sum".into(),
                axes: vec![0],
            },
            metadata: None,
        };

        graph.nodes.push(matmul);
        graph.nodes.push(relu);
        graph.nodes.push(total);
        graph
    }

    /// Schedules the shared test graph with `strategy`.
    fn schedule_with(strategy: SchedulingStrategy) -> ExecutionSchedule {
        Scheduler::new(strategy).schedule(&create_test_graph())
    }

    #[test]
    fn test_execution_schedule_creation() {
        let empty = ExecutionSchedule::new();
        assert!(empty.is_empty());
        assert_eq!(empty.num_parallel_stages(), 0);
    }

    #[test]
    fn test_sequential_schedule() {
        let plan = ExecutionSchedule::sequential(5, DeviceType::CPU);
        assert_eq!(plan.len(), 5);
        assert_eq!(plan.execution_order, (0..5).collect::<Vec<usize>>());
        assert_eq!(plan.num_parallel_stages(), 5);
        assert!((0..5).all(|idx| plan.get_device(idx) == Some(DeviceType::CPU)));
    }

    #[test]
    fn test_node_cost_estimation() {
        let cost = NodeCost::estimate_from_node(&EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        });

        assert!(cost.compute_cost > 0.0);
        assert!(cost.total_cost() > 0.0);
    }

    #[test]
    fn test_scheduler_sequential() {
        let plan = schedule_with(SchedulingStrategy::Sequential);
        assert_eq!(plan.len(), 3);
        assert_eq!(plan.execution_order, vec![0, 1, 2]);
    }

    #[test]
    fn test_scheduler_parallel() {
        let plan = schedule_with(SchedulingStrategy::MaximizeParallelism);
        assert_eq!(plan.len(), 3);
        // A three-node chain can never need more than three stages.
        assert!(plan.num_parallel_stages() <= 3);
    }

    #[test]
    fn test_scheduler_memory_efficient() {
        let plan = schedule_with(SchedulingStrategy::MinimizeMemory);
        assert_eq!(plan.len(), 3);
        // Every node of the chain must be scheduled.
        for node_idx in 0..3 {
            assert!(plan.execution_order.contains(&node_idx));
        }
    }

    #[test]
    fn test_scheduler_balanced() {
        let plan = schedule_with(SchedulingStrategy::Balanced);
        assert_eq!(plan.len(), 3);
        assert!(plan.estimated_cost > 0.0);
    }

    #[test]
    fn test_scheduler_cost_based() {
        let plan = schedule_with(SchedulingStrategy::CostBased);
        assert_eq!(plan.len(), 3);
        assert!(plan.estimated_cost > 0.0);
    }

    #[test]
    fn test_dependency_graph_building() {
        let deps = Scheduler::default().build_dependency_graph(&create_test_graph());

        assert_eq!(deps.len(), 3);
        // Node 0 reads only graph inputs; each later node reads its predecessor's output.
        assert!(deps[&0].is_empty());
        assert_eq!(deps[&1], vec![0]);
        assert_eq!(deps[&2], vec![1]);
    }

    #[test]
    fn test_node_levels() {
        let graph = create_test_graph();
        let scheduler = Scheduler::default();
        let deps = scheduler.build_dependency_graph(&graph);
        let levels = scheduler.compute_node_levels(&graph, &deps);

        // In a simple chain, each node's level equals its position in the chain.
        for node_idx in 0..3 {
            assert_eq!(levels[&node_idx], node_idx);
        }
    }

    #[test]
    fn test_scheduling_strategies() {
        let graph = create_test_graph();
        let every_strategy = [
            SchedulingStrategy::Sequential,
            SchedulingStrategy::MaximizeParallelism,
            SchedulingStrategy::MinimizeMemory,
            SchedulingStrategy::Balanced,
            SchedulingStrategy::CostBased,
        ];

        for &strategy in &every_strategy {
            let plan = Scheduler::new(strategy).schedule(&graph);
            assert_eq!(plan.len(), 3, "Strategy {:?} failed", strategy);
        }
    }
}