// tensorlogic_ir/graph/schedule.rs

1//! Advanced scheduling strategies for tensor graph execution.
2//!
3//! This module provides sophisticated scheduling algorithms that optimize
4//! for different objectives: latency, throughput, resource utilization,
5//! and multi-objective trade-offs.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8
9use super::EinsumGraph;
10use crate::error::IrError;
11
/// Objective that a [`GraphScheduler`] optimizes a schedule for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingObjective {
    /// Minimize end-to-end latency by packing the critical path tightly.
    MinimizeLatency,
    /// Maximize operations completed per unit time.
    MaximizeThroughput,
    /// Keep peak memory usage as low as possible.
    MinimizeMemory,
    /// Trade off latency against memory pressure.
    Balanced,
    /// Partition work into stages suitable for pipelined execution.
    Pipeline,
}
26
27/// A schedule for executing graph operations
28#[derive(Debug, Clone)]
29pub struct ExecutionSchedule {
30    /// Ordered sequence of operations to execute
31    pub execution_order: Vec<usize>,
32    /// Operations that can execute in parallel at each step
33    pub parallel_stages: Vec<Vec<usize>>,
34    /// Estimated execution time for each stage
35    pub stage_costs: Vec<f64>,
36    /// Total estimated execution time
37    pub total_cost: f64,
38    /// Peak memory usage
39    pub peak_memory: usize,
40    /// Objective used for scheduling
41    pub objective: SchedulingObjective,
42}
43
44impl ExecutionSchedule {
45    /// Create a new execution schedule
46    pub fn new(objective: SchedulingObjective) -> Self {
47        Self {
48            execution_order: Vec::new(),
49            parallel_stages: Vec::new(),
50            stage_costs: Vec::new(),
51            total_cost: 0.0,
52            peak_memory: 0,
53            objective,
54        }
55    }
56
57    /// Get the number of stages in the schedule
58    pub fn num_stages(&self) -> usize {
59        self.parallel_stages.len()
60    }
61
62    /// Get maximum parallelism across all stages
63    pub fn max_parallelism(&self) -> usize {
64        self.parallel_stages
65            .iter()
66            .map(|s| s.len())
67            .max()
68            .unwrap_or(0)
69    }
70
71    /// Get average parallelism
72    pub fn avg_parallelism(&self) -> f64 {
73        if self.parallel_stages.is_empty() {
74            return 0.0;
75        }
76        let total: usize = self.parallel_stages.iter().map(|s| s.len()).sum();
77        total as f64 / self.parallel_stages.len() as f64
78    }
79}
80
/// Advanced scheduler for computation graphs.
///
/// Holds the cost model driving every strategy: per-operation execution
/// costs (default 1.0 when unset) and per-tensor memory footprints
/// (default 1 when unset).
pub struct GraphScheduler {
    /// Estimated execution cost keyed by node index.
    operation_costs: HashMap<usize, f64>,
    /// Memory footprint keyed by tensor index.
    tensor_memory: HashMap<usize, usize>,
}
88
89impl GraphScheduler {
90    /// Create a new scheduler
91    pub fn new() -> Self {
92        Self {
93            operation_costs: HashMap::new(),
94            tensor_memory: HashMap::new(),
95        }
96    }
97
98    /// Set the cost for an operation
99    pub fn set_operation_cost(&mut self, node_idx: usize, cost: f64) {
100        self.operation_costs.insert(node_idx, cost);
101    }
102
103    /// Set memory size for a tensor
104    pub fn set_tensor_memory(&mut self, tensor_idx: usize, size: usize) {
105        self.tensor_memory.insert(tensor_idx, size);
106    }
107
108    /// Generate a schedule optimized for the given objective
109    pub fn schedule(
110        &self,
111        graph: &EinsumGraph,
112        objective: SchedulingObjective,
113    ) -> Result<ExecutionSchedule, IrError> {
114        match objective {
115            SchedulingObjective::MinimizeLatency => self.schedule_min_latency(graph),
116            SchedulingObjective::MaximizeThroughput => self.schedule_max_throughput(graph),
117            SchedulingObjective::MinimizeMemory => self.schedule_min_memory(graph),
118            SchedulingObjective::Balanced => self.schedule_balanced(graph),
119            SchedulingObjective::Pipeline => self.schedule_pipeline(graph),
120        }
121    }
122
123    /// Schedule to minimize latency (critical path)
124    fn schedule_min_latency(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
125        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
126
127        // Build dependency graph
128        let dependencies = self.build_dependencies(graph);
129
130        // Compute earliest start times using critical path analysis
131        let start_times = self.compute_start_times(graph, &dependencies);
132
133        // Group by start time to create parallel stages
134        let mut stages: HashMap<usize, Vec<usize>> = HashMap::new();
135        for (node_idx, &start_time) in start_times.iter().enumerate() {
136            stages
137                .entry(start_time as usize)
138                .or_default()
139                .push(node_idx);
140        }
141
142        // Sort stages and build schedule
143        let mut stage_indices: Vec<_> = stages.keys().copied().collect();
144        stage_indices.sort_unstable();
145
146        for stage_idx in stage_indices {
147            if let Some(nodes) = stages.get(&stage_idx) {
148                let stage_cost = nodes
149                    .iter()
150                    .map(|&idx| self.get_operation_cost(idx))
151                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
152                    .unwrap_or(0.0);
153
154                schedule.parallel_stages.push(nodes.clone());
155                schedule.stage_costs.push(stage_cost);
156                schedule.total_cost += stage_cost;
157
158                for &node in nodes {
159                    schedule.execution_order.push(node);
160                }
161            }
162        }
163
164        Ok(schedule)
165    }
166
167    /// Schedule to maximize throughput
168    fn schedule_max_throughput(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
169        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MaximizeThroughput);
170
171        // Use list scheduling with longest processing time first
172        let dependencies = self.build_dependencies(graph);
173        #[allow(clippy::unnecessary_map_or)]
174        let mut ready: Vec<usize> = (0..graph.nodes.len())
175            .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
176            .collect();
177
178        // Sort by cost (descending) for better load balancing
179        ready.sort_by(|&a, &b| {
180            let cost_a = self.get_operation_cost(a);
181            let cost_b = self.get_operation_cost(b);
182            cost_b
183                .partial_cmp(&cost_a)
184                .unwrap_or(std::cmp::Ordering::Equal)
185        });
186
187        let mut scheduled = HashSet::new();
188        let _in_degree = self.compute_in_degrees(graph, &dependencies);
189
190        while !ready.is_empty() {
191            let mut stage = Vec::new();
192            let mut stage_cost: f64 = 0.0;
193
194            // Schedule all ready operations in this stage
195            for &node_idx in &ready {
196                let cost = self.get_operation_cost(node_idx);
197                stage.push(node_idx);
198                stage_cost = stage_cost.max(cost);
199                scheduled.insert(node_idx);
200                schedule.execution_order.push(node_idx);
201            }
202
203            schedule.parallel_stages.push(stage);
204            schedule.stage_costs.push(stage_cost);
205            schedule.total_cost += stage_cost;
206
207            // Update ready list
208            ready.clear();
209            for (node_idx, deps) in &dependencies {
210                if scheduled.contains(node_idx) {
211                    continue;
212                }
213
214                let all_deps_scheduled = deps.iter().all(|&dep| scheduled.contains(&dep));
215                if all_deps_scheduled {
216                    ready.push(*node_idx);
217                }
218            }
219
220            // Sort by cost again
221            ready.sort_by(|&a, &b| {
222                let cost_a = self.get_operation_cost(a);
223                let cost_b = self.get_operation_cost(b);
224                cost_b
225                    .partial_cmp(&cost_a)
226                    .unwrap_or(std::cmp::Ordering::Equal)
227            });
228        }
229
230        Ok(schedule)
231    }
232
233    /// Schedule to minimize memory usage
234    fn schedule_min_memory(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
235        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeMemory);
236
237        // Use earliest deadline first with memory pressure
238        let dependencies = self.build_dependencies(graph);
239        let tensor_lifetimes = self.compute_tensor_lifetimes(graph);
240
241        #[allow(clippy::unnecessary_map_or)]
242        let mut ready: Vec<usize> = (0..graph.nodes.len())
243            .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
244            .collect();
245
246        let mut scheduled = HashSet::new();
247
248        while !ready.is_empty() {
249            // Choose operation that frees the most memory
250            let best_idx = ready
251                .iter()
252                .max_by_key(|&&idx| self.estimate_memory_freed(graph, idx, &tensor_lifetimes))
253                .copied()
254                .expect("ready list is non-empty at this point in the loop");
255
256            ready.retain(|&idx| idx != best_idx);
257
258            schedule.execution_order.push(best_idx);
259            schedule.parallel_stages.push(vec![best_idx]);
260            let cost = self.get_operation_cost(best_idx);
261            schedule.stage_costs.push(cost);
262            schedule.total_cost += cost;
263            scheduled.insert(best_idx);
264
265            // Update ready list
266            for (node_idx, deps) in &dependencies {
267                if scheduled.contains(node_idx) || ready.contains(node_idx) {
268                    continue;
269                }
270
271                if deps.iter().all(|&dep| scheduled.contains(&dep)) {
272                    ready.push(*node_idx);
273                }
274            }
275        }
276
277        Ok(schedule)
278    }
279
280    /// Schedule with balanced objectives
281    fn schedule_balanced(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
282        // Use a weighted combination of latency and memory objectives
283        let latency_schedule = self.schedule_min_latency(graph)?;
284        let _memory_schedule = self.schedule_min_memory(graph)?;
285
286        // For now, prefer latency schedule with memory awareness
287        // In a full implementation, we would use multi-objective optimization
288        Ok(latency_schedule)
289    }
290
291    /// Schedule for pipeline execution
292    fn schedule_pipeline(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
293        let mut schedule = ExecutionSchedule::new(SchedulingObjective::Pipeline);
294
295        // Partition graph into pipeline stages
296        let stages = self.partition_for_pipeline(graph)?;
297
298        for stage_nodes in stages {
299            let stage_cost = stage_nodes
300                .iter()
301                .map(|&idx| self.get_operation_cost(idx))
302                .sum();
303
304            schedule.parallel_stages.push(stage_nodes.clone());
305            schedule.stage_costs.push(stage_cost);
306            schedule.total_cost = schedule.total_cost.max(stage_cost);
307
308            for &node in &stage_nodes {
309                schedule.execution_order.push(node);
310            }
311        }
312
313        Ok(schedule)
314    }
315
316    /// Build dependency graph
317    fn build_dependencies(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
318        let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
319        let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
320
321        // Map each tensor to its producer
322        for (node_idx, node) in graph.nodes.iter().enumerate() {
323            for &output_idx in &node.outputs {
324                tensor_producer.insert(output_idx, node_idx);
325            }
326        }
327
328        // Build dependencies
329        for (node_idx, node) in graph.nodes.iter().enumerate() {
330            let mut deps = Vec::new();
331            for &input_idx in &node.inputs {
332                if let Some(&producer) = tensor_producer.get(&input_idx) {
333                    if producer != node_idx {
334                        deps.push(producer);
335                    }
336                }
337            }
338            dependencies.insert(node_idx, deps);
339        }
340
341        dependencies
342    }
343
344    /// Compute earliest start times for each operation
345    fn compute_start_times(
346        &self,
347        graph: &EinsumGraph,
348        dependencies: &HashMap<usize, Vec<usize>>,
349    ) -> Vec<f64> {
350        let mut start_times = vec![0.0; graph.nodes.len()];
351        let mut visited = HashSet::new();
352        let mut queue = VecDeque::new();
353
354        // Find roots (nodes with no dependencies)
355        for (node_idx, deps) in dependencies {
356            if deps.is_empty() {
357                queue.push_back(*node_idx);
358            }
359        }
360
361        while let Some(node_idx) = queue.pop_front() {
362            if visited.contains(&node_idx) {
363                continue;
364            }
365            visited.insert(node_idx);
366
367            // Compute start time based on dependencies
368            let deps = dependencies
369                .get(&node_idx)
370                .map(|v| v.as_slice())
371                .unwrap_or(&[]);
372            let max_dep_finish = deps
373                .iter()
374                .map(|&dep_idx| start_times[dep_idx] + self.get_operation_cost(dep_idx))
375                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
376                .unwrap_or(0.0);
377
378            start_times[node_idx] = max_dep_finish;
379
380            // Add successors to queue
381            for (succ_idx, succ_deps) in dependencies {
382                if succ_deps.contains(&node_idx) && !visited.contains(succ_idx) {
383                    queue.push_back(*succ_idx);
384                }
385            }
386        }
387
388        start_times
389    }
390
391    /// Compute in-degrees for topological sort
392    fn compute_in_degrees(
393        &self,
394        graph: &EinsumGraph,
395        dependencies: &HashMap<usize, Vec<usize>>,
396    ) -> Vec<usize> {
397        let mut in_degree = vec![0; graph.nodes.len()];
398        for (node_idx, deps) in dependencies {
399            in_degree[*node_idx] = deps.len();
400        }
401        in_degree
402    }
403
404    /// Compute tensor lifetimes
405    fn compute_tensor_lifetimes(&self, graph: &EinsumGraph) -> HashMap<usize, (usize, usize)> {
406        let mut lifetimes = HashMap::new();
407
408        for (node_idx, node) in graph.nodes.iter().enumerate() {
409            for &tensor_idx in &node.inputs {
410                let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
411                entry.0 = entry.0.min(node_idx);
412                entry.1 = entry.1.max(node_idx);
413            }
414            for &tensor_idx in &node.outputs {
415                let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
416                entry.0 = entry.0.min(node_idx);
417                entry.1 = entry.1.max(node_idx);
418            }
419        }
420
421        lifetimes
422    }
423
424    /// Estimate memory freed by executing an operation
425    fn estimate_memory_freed(
426        &self,
427        graph: &EinsumGraph,
428        node_idx: usize,
429        lifetimes: &HashMap<usize, (usize, usize)>,
430    ) -> usize {
431        let node = &graph.nodes[node_idx];
432        let mut freed = 0;
433
434        for &input_tensor in &node.inputs {
435            if let Some(&(_, last_use)) = lifetimes.get(&input_tensor) {
436                if last_use == node_idx {
437                    freed += self.tensor_memory.get(&input_tensor).copied().unwrap_or(1);
438                }
439            }
440        }
441
442        freed
443    }
444
445    /// Partition graph into pipeline stages
446    fn partition_for_pipeline(&self, graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
447        // Simple partitioning: divide into roughly equal-cost stages
448        let total_cost: f64 = (0..graph.nodes.len())
449            .map(|i| self.get_operation_cost(i))
450            .sum();
451
452        let target_stages = 4; // Default number of pipeline stages
453        let target_cost_per_stage = total_cost / target_stages as f64;
454
455        let dependencies = self.build_dependencies(graph);
456        let topo_order = self.topological_sort(graph, &dependencies);
457
458        let mut stages = Vec::new();
459        let mut current_stage = Vec::new();
460        let mut current_cost = 0.0;
461
462        for &node_idx in &topo_order {
463            let cost = self.get_operation_cost(node_idx);
464            current_stage.push(node_idx);
465            current_cost += cost;
466
467            if current_cost >= target_cost_per_stage {
468                stages.push(current_stage.clone());
469                current_stage.clear();
470                current_cost = 0.0;
471            }
472        }
473
474        if !current_stage.is_empty() {
475            stages.push(current_stage);
476        }
477
478        Ok(stages)
479    }
480
481    /// Topological sort of the graph
482    fn topological_sort(
483        &self,
484        graph: &EinsumGraph,
485        dependencies: &HashMap<usize, Vec<usize>>,
486    ) -> Vec<usize> {
487        let mut result = Vec::new();
488        let mut visited = HashSet::new();
489        let mut in_degree = self.compute_in_degrees(graph, dependencies);
490
491        let mut queue: VecDeque<usize> = (0..graph.nodes.len())
492            .filter(|&i| in_degree[i] == 0)
493            .collect();
494
495        while let Some(node_idx) = queue.pop_front() {
496            if visited.contains(&node_idx) {
497                continue;
498            }
499            visited.insert(node_idx);
500            result.push(node_idx);
501
502            // Update successors
503            for (succ_idx, deps) in dependencies {
504                if deps.contains(&node_idx) {
505                    in_degree[*succ_idx] = in_degree[*succ_idx].saturating_sub(1);
506                    if in_degree[*succ_idx] == 0 {
507                        queue.push_back(*succ_idx);
508                    }
509                }
510            }
511        }
512
513        result
514    }
515
516    /// Get operation cost (with default)
517    fn get_operation_cost(&self, node_idx: usize) -> f64 {
518        self.operation_costs.get(&node_idx).copied().unwrap_or(1.0)
519    }
520}
521
522impl Default for GraphScheduler {
523    fn default() -> Self {
524        Self::new()
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    // A fresh schedule carries its objective and has no stages.
    #[test]
    fn test_execution_schedule_creation() {
        let schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.objective, SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.num_stages(), 0);
    }

    // max_parallelism is the widest stage; avg_parallelism is ops / stages.
    #[test]
    fn test_execution_schedule_stats() {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        schedule.parallel_stages.push(vec![0, 1, 2]);
        schedule.parallel_stages.push(vec![3]);

        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.max_parallelism(), 3);
        assert_eq!(schedule.avg_parallelism(), 2.0);
    }

    // A new scheduler starts with an empty cost model.
    #[test]
    fn test_scheduler_creation() {
        let scheduler = GraphScheduler::new();
        assert!(scheduler.operation_costs.is_empty());
    }

    // Configured costs/sizes are stored and retrievable.
    #[test]
    fn test_scheduler_set_costs() {
        let mut scheduler = GraphScheduler::new();
        scheduler.set_operation_cost(0, 5.0);
        scheduler.set_tensor_memory(1, 1024);

        assert_eq!(scheduler.get_operation_cost(0), 5.0);
        assert_eq!(scheduler.tensor_memory.get(&1), Some(&1024));
    }

    // Scheduling an empty graph succeeds and yields zero stages.
    #[test]
    fn test_schedule_empty_graph() {
        let scheduler = GraphScheduler::new();
        let graph = EinsumGraph::new();

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .expect("unwrap");
        assert_eq!(schedule.num_stages(), 0);
    }

    // A single node is scheduled alone; total cost equals its configured cost.
    #[test]
    fn test_schedule_single_node() {
        let mut scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");

        scheduler.set_operation_cost(0, 2.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .expect("unwrap");
        assert_eq!(schedule.execution_order.len(), 1);
        assert_eq!(schedule.total_cost, 2.0);
    }

    // Node 1 consumes node 0's output tensor, so it depends on node 0.
    #[test]
    fn test_build_dependencies() {
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        // Chain: relu produces b from a; tanh produces c from b.
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .expect("unwrap");

        let deps = scheduler.build_dependencies(&graph);
        assert_eq!(deps.get(&0).expect("unwrap").len(), 0);
        assert_eq!(deps.get(&1).expect("unwrap"), &vec![0]);
    }

    // Topological sort of a two-node chain must order producer before consumer.
    #[test]
    fn test_topological_sort() {
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .expect("unwrap");

        let deps = scheduler.build_dependencies(&graph);
        let topo = scheduler.topological_sort(&graph, &deps);

        assert_eq!(topo.len(), 2);
        assert_eq!(topo[0], 0);
        assert_eq!(topo[1], 1);
    }

    // Sanity-check the derived equality on the objective enum.
    #[test]
    fn test_scheduling_objectives() {
        assert_eq!(
            SchedulingObjective::MinimizeLatency,
            SchedulingObjective::MinimizeLatency
        );
        assert_ne!(
            SchedulingObjective::MinimizeLatency,
            SchedulingObjective::MaximizeThroughput
        );
    }
}