// tensorlogic_ir/graph/schedule.rs

//! Advanced scheduling strategies for tensor graph execution.
//!
//! This module provides sophisticated scheduling algorithms that optimize
//! for different objectives: latency, throughput, resource utilization,
//! and multi-objective trade-offs.
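//!
//! # Example
//!
//! A minimal sketch of driving the scheduler (crate paths are assumed from
//! this module's location; the graph constructors are the ones used in the
//! tests at the bottom of this file):
//!
//! ```ignore
//! use tensorlogic_ir::graph::{EinsumGraph, EinsumNode};
//! use tensorlogic_ir::graph::schedule::{GraphScheduler, SchedulingObjective};
//!
//! let mut graph = EinsumGraph::new();
//! let a = graph.add_tensor("A");
//! let b = graph.add_tensor("B");
//! graph.add_node(EinsumNode::elem_unary("relu", a, b)).unwrap();
//!
//! let mut scheduler = GraphScheduler::new();
//! scheduler.set_operation_cost(0, 2.0);
//!
//! let schedule = scheduler
//!     .schedule(&graph, SchedulingObjective::MinimizeLatency)
//!     .unwrap();
//! assert_eq!(schedule.num_stages(), 1);
//! assert_eq!(schedule.total_cost, 2.0);
//! ```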

use std::collections::{HashMap, HashSet, VecDeque};

use super::EinsumGraph;
use crate::error::IrError;

/// Scheduling objective to optimize for
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingObjective {
    /// Minimize total execution latency (critical path)
    MinimizeLatency,
    /// Maximize throughput (operations per unit time)
    MaximizeThroughput,
    /// Minimize peak memory usage
    MinimizeMemory,
    /// Balance between latency and memory
    Balanced,
    /// Optimize for pipeline execution
    Pipeline,
}

/// A schedule for executing graph operations
#[derive(Debug, Clone)]
pub struct ExecutionSchedule {
    /// Ordered sequence of operations to execute
    pub execution_order: Vec<usize>,
    /// Operations that can execute in parallel at each step
    pub parallel_stages: Vec<Vec<usize>>,
    /// Estimated execution time for each stage
    pub stage_costs: Vec<f64>,
    /// Total estimated execution time
    pub total_cost: f64,
    /// Peak memory usage (currently left at zero by the built-in strategies)
    pub peak_memory: usize,
    /// Objective used for scheduling
    pub objective: SchedulingObjective,
}

impl ExecutionSchedule {
    /// Create a new execution schedule
    pub fn new(objective: SchedulingObjective) -> Self {
        Self {
            execution_order: Vec::new(),
            parallel_stages: Vec::new(),
            stage_costs: Vec::new(),
            total_cost: 0.0,
            peak_memory: 0,
            objective,
        }
    }

    /// Get the number of stages in the schedule
    pub fn num_stages(&self) -> usize {
        self.parallel_stages.len()
    }

    /// Get maximum parallelism across all stages
    pub fn max_parallelism(&self) -> usize {
        self.parallel_stages
            .iter()
            .map(|s| s.len())
            .max()
            .unwrap_or(0)
    }

    /// Get average parallelism
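    ///
    /// # Example
    ///
    /// A sketch (imports elided); stages of width 3 and 1 average `2.0`:
    ///
    /// ```ignore
    /// let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
    /// schedule.parallel_stages.push(vec![0, 1, 2]);
    /// schedule.parallel_stages.push(vec![3]);
    /// assert_eq!(schedule.avg_parallelism(), 2.0);
    /// ```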
    pub fn avg_parallelism(&self) -> f64 {
        if self.parallel_stages.is_empty() {
            return 0.0;
        }
        let total: usize = self.parallel_stages.iter().map(|s| s.len()).sum();
        total as f64 / self.parallel_stages.len() as f64
    }
}

/// Advanced scheduler for computation graphs
pub struct GraphScheduler {
    /// Cost model for operations
    operation_costs: HashMap<usize, f64>,
    /// Memory usage per tensor
    tensor_memory: HashMap<usize, usize>,
}

impl GraphScheduler {
    /// Create a new scheduler
    pub fn new() -> Self {
        Self {
            operation_costs: HashMap::new(),
            tensor_memory: HashMap::new(),
        }
    }

    /// Set the cost for an operation
    ///
    /// Nodes without an explicit cost default to a cost of `1.0`.
    pub fn set_operation_cost(&mut self, node_idx: usize, cost: f64) {
        self.operation_costs.insert(node_idx, cost);
    }

    /// Set memory size for a tensor
    ///
    /// Tensors without an explicit size default to a size of `1`.
    pub fn set_tensor_memory(&mut self, tensor_idx: usize, size: usize) {
        self.tensor_memory.insert(tensor_idx, size);
    }

    /// Generate a schedule optimized for the given objective
    pub fn schedule(
        &self,
        graph: &EinsumGraph,
        objective: SchedulingObjective,
    ) -> Result<ExecutionSchedule, IrError> {
        match objective {
            SchedulingObjective::MinimizeLatency => self.schedule_min_latency(graph),
            SchedulingObjective::MaximizeThroughput => self.schedule_max_throughput(graph),
            SchedulingObjective::MinimizeMemory => self.schedule_min_memory(graph),
            SchedulingObjective::Balanced => self.schedule_balanced(graph),
            SchedulingObjective::Pipeline => self.schedule_pipeline(graph),
        }
    }

    /// Schedule to minimize latency (critical path)
    fn schedule_min_latency(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);

        // Build dependency graph
        let dependencies = self.build_dependencies(graph);

        // Compute earliest start times using critical path analysis
        let start_times = self.compute_start_times(graph, &dependencies);

        // Group nodes with equal start times into parallel stages. Comparing
        // the f64 start times directly (rather than truncating to usize)
        // keeps dependent nodes with fractional costs in separate stages.
        let mut order: Vec<usize> = (0..start_times.len()).collect();
        order.sort_by(|&a, &b| start_times[a].partial_cmp(&start_times[b]).unwrap());

        let mut stages: Vec<Vec<usize>> = Vec::new();
        let mut stage_starts: Vec<f64> = Vec::new();
        for node_idx in order {
            let t = start_times[node_idx];
            match stage_starts.last() {
                Some(&last) if last == t => stages.last_mut().unwrap().push(node_idx),
                _ => {
                    stage_starts.push(t);
                    stages.push(vec![node_idx]);
                }
            }
        }

        for nodes in stages {
            // A stage's cost is that of its slowest operation, since the
            // rest of the stage runs in parallel with it.
            let stage_cost = nodes
                .iter()
                .map(|&idx| self.get_operation_cost(idx))
                .max_by(|a, b| a.partial_cmp(b).unwrap())
                .unwrap_or(0.0);

            schedule.execution_order.extend(&nodes);
            schedule.parallel_stages.push(nodes);
            schedule.stage_costs.push(stage_cost);
            schedule.total_cost += stage_cost;
        }

        Ok(schedule)
    }

    /// Schedule to maximize throughput
    fn schedule_max_throughput(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MaximizeThroughput);

        // Use list scheduling with longest-processing-time-first ordering
        let dependencies = self.build_dependencies(graph);
        #[allow(clippy::unnecessary_map_or)]
        let mut ready: Vec<usize> = (0..graph.nodes.len())
            .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
            .collect();

        // Sort by cost (descending) for better load balancing
        ready.sort_by(|&a, &b| {
            let cost_a = self.get_operation_cost(a);
            let cost_b = self.get_operation_cost(b);
            cost_b.partial_cmp(&cost_a).unwrap()
        });

        let mut scheduled = HashSet::new();

        while !ready.is_empty() {
            let mut stage = Vec::new();
            let mut stage_cost: f64 = 0.0;

            // Schedule all ready operations in this stage; the stage cost is
            // that of its slowest operation, since the others run alongside it
            for &node_idx in &ready {
                let cost = self.get_operation_cost(node_idx);
                stage.push(node_idx);
                stage_cost = stage_cost.max(cost);
                scheduled.insert(node_idx);
                schedule.execution_order.push(node_idx);
            }

            schedule.parallel_stages.push(stage);
            schedule.stage_costs.push(stage_cost);
            schedule.total_cost += stage_cost;

            // Refill the ready list with nodes whose dependencies are now met
            ready.clear();
            for (node_idx, deps) in &dependencies {
                if scheduled.contains(node_idx) {
                    continue;
                }

                if deps.iter().all(|&dep| scheduled.contains(&dep)) {
                    ready.push(*node_idx);
                }
            }

            // Sort by cost again
            ready.sort_by(|&a, &b| {
                let cost_a = self.get_operation_cost(a);
                let cost_b = self.get_operation_cost(b);
                cost_b.partial_cmp(&cost_a).unwrap()
            });
        }

        Ok(schedule)
    }

    /// Schedule to minimize memory usage
    fn schedule_min_memory(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeMemory);

        // Greedy heuristic: repeatedly execute the ready operation that
        // frees the most memory (i.e., whose inputs see their last use)
        let dependencies = self.build_dependencies(graph);
        let tensor_lifetimes = self.compute_tensor_lifetimes(graph);

        #[allow(clippy::unnecessary_map_or)]
        let mut ready: Vec<usize> = (0..graph.nodes.len())
            .filter(|&i| dependencies.get(&i).map_or(true, |deps| deps.is_empty()))
            .collect();

        let mut scheduled = HashSet::new();

        while !ready.is_empty() {
            // Choose the operation that frees the most memory
            let best_idx = ready
                .iter()
                .max_by_key(|&&idx| self.estimate_memory_freed(graph, idx, &tensor_lifetimes))
                .copied()
                .unwrap();

            ready.retain(|&idx| idx != best_idx);

            schedule.execution_order.push(best_idx);
            schedule.parallel_stages.push(vec![best_idx]);
            let cost = self.get_operation_cost(best_idx);
            schedule.stage_costs.push(cost);
            schedule.total_cost += cost;
            scheduled.insert(best_idx);

            // Update ready list
            for (node_idx, deps) in &dependencies {
                if scheduled.contains(node_idx) || ready.contains(node_idx) {
                    continue;
                }

                if deps.iter().all(|&dep| scheduled.contains(&dep)) {
                    ready.push(*node_idx);
                }
            }
        }

        Ok(schedule)
    }

    /// Schedule with balanced objectives
    fn schedule_balanced(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
        // Use a weighted combination of latency and memory objectives
        let latency_schedule = self.schedule_min_latency(graph)?;
        let _memory_schedule = self.schedule_min_memory(graph)?;

        // For now, fall back to the latency schedule; a full implementation
        // would trade it off against the memory schedule via
        // multi-objective optimization
        Ok(latency_schedule)
    }

    /// Schedule for pipeline execution
    fn schedule_pipeline(&self, graph: &EinsumGraph) -> Result<ExecutionSchedule, IrError> {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::Pipeline);

        // Partition graph into pipeline stages
        let stages = self.partition_for_pipeline(graph)?;

        for stage_nodes in stages {
            let stage_cost: f64 = stage_nodes
                .iter()
                .map(|&idx| self.get_operation_cost(idx))
                .sum();

            schedule.parallel_stages.push(stage_nodes.clone());
            schedule.stage_costs.push(stage_cost);
            // In steady state a pipeline is limited by its slowest stage,
            // so the total cost is the bottleneck stage cost, not the sum
            schedule.total_cost = schedule.total_cost.max(stage_cost);

            for &node in &stage_nodes {
                schedule.execution_order.push(node);
            }
        }

        Ok(schedule)
    }

    /// Build dependency graph
    fn build_dependencies(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
        let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
        let mut tensor_producer: HashMap<usize, usize> = HashMap::new();

        // Map each tensor to its producer
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            for &output_idx in &node.outputs {
                tensor_producer.insert(output_idx, node_idx);
            }
        }

        // Build dependencies, deduplicating so a node that consumes two
        // tensors from the same producer contributes a single edge
        // (duplicate edges would skew the in-degree counts)
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            let mut deps = Vec::new();
            for &input_idx in &node.inputs {
                if let Some(&producer) = tensor_producer.get(&input_idx) {
                    if producer != node_idx && !deps.contains(&producer) {
                        deps.push(producer);
                    }
                }
            }
            dependencies.insert(node_idx, deps);
        }

        dependencies
    }

    /// Compute earliest start times for each operation
    fn compute_start_times(
        &self,
        graph: &EinsumGraph,
        dependencies: &HashMap<usize, Vec<usize>>,
    ) -> Vec<f64> {
        let mut start_times = vec![0.0; graph.nodes.len()];

        // Walk the nodes in topological order so every dependency's start
        // time is final before its dependents are computed (a plain BFS can
        // visit a node before all of its dependencies have been processed)
        for node_idx in self.topological_sort(graph, dependencies) {
            let deps = dependencies
                .get(&node_idx)
                .map(|v| v.as_slice())
                .unwrap_or(&[]);
            let max_dep_finish = deps
                .iter()
                .map(|&dep_idx| start_times[dep_idx] + self.get_operation_cost(dep_idx))
                .max_by(|a, b| a.partial_cmp(b).unwrap())
                .unwrap_or(0.0);

            start_times[node_idx] = max_dep_finish;
        }

        start_times
    }

    /// Compute in-degrees for topological sort
    fn compute_in_degrees(
        &self,
        graph: &EinsumGraph,
        dependencies: &HashMap<usize, Vec<usize>>,
    ) -> Vec<usize> {
        let mut in_degree = vec![0; graph.nodes.len()];
        for (node_idx, deps) in dependencies {
            in_degree[*node_idx] = deps.len();
        }
        in_degree
    }

    /// Compute tensor lifetimes as (first use, last use) node indices
    fn compute_tensor_lifetimes(&self, graph: &EinsumGraph) -> HashMap<usize, (usize, usize)> {
        let mut lifetimes = HashMap::new();

        for (node_idx, node) in graph.nodes.iter().enumerate() {
            for &tensor_idx in &node.inputs {
                let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
                entry.0 = entry.0.min(node_idx);
                entry.1 = entry.1.max(node_idx);
            }
            for &tensor_idx in &node.outputs {
                let entry = lifetimes.entry(tensor_idx).or_insert((node_idx, node_idx));
                entry.0 = entry.0.min(node_idx);
                entry.1 = entry.1.max(node_idx);
            }
        }

        lifetimes
    }

    /// Estimate memory freed by executing an operation
    fn estimate_memory_freed(
        &self,
        graph: &EinsumGraph,
        node_idx: usize,
        lifetimes: &HashMap<usize, (usize, usize)>,
    ) -> usize {
        let node = &graph.nodes[node_idx];
        let mut freed = 0;

        // An input tensor can be freed once its last consumer has run;
        // tensors with no recorded size are counted as size 1
        for &input_tensor in &node.inputs {
            if let Some(&(_, last_use)) = lifetimes.get(&input_tensor) {
                if last_use == node_idx {
                    freed += self.tensor_memory.get(&input_tensor).copied().unwrap_or(1);
                }
            }
        }

        freed
    }

    /// Partition graph into pipeline stages
    fn partition_for_pipeline(&self, graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
        // Simple partitioning: divide into roughly equal-cost stages
        let total_cost: f64 = (0..graph.nodes.len())
            .map(|i| self.get_operation_cost(i))
            .sum();

        let target_stages = 4; // Default number of pipeline stages
        let target_cost_per_stage = total_cost / target_stages as f64;

        let dependencies = self.build_dependencies(graph);
        let topo_order = self.topological_sort(graph, &dependencies);

        let mut stages = Vec::new();
        let mut current_stage = Vec::new();
        let mut current_cost = 0.0;

        for &node_idx in &topo_order {
            let cost = self.get_operation_cost(node_idx);
            current_stage.push(node_idx);
            current_cost += cost;

            if current_cost >= target_cost_per_stage {
                stages.push(current_stage.clone());
                current_stage.clear();
                current_cost = 0.0;
            }
        }

        if !current_stage.is_empty() {
            stages.push(current_stage);
        }

        Ok(stages)
    }

    /// Topological sort of the graph
    fn topological_sort(
        &self,
        graph: &EinsumGraph,
        dependencies: &HashMap<usize, Vec<usize>>,
    ) -> Vec<usize> {
        let mut result = Vec::new();
        let mut visited = HashSet::new();
        let mut in_degree = self.compute_in_degrees(graph, dependencies);

        let mut queue: VecDeque<usize> = (0..graph.nodes.len())
            .filter(|&i| in_degree[i] == 0)
            .collect();

        while let Some(node_idx) = queue.pop_front() {
            if visited.contains(&node_idx) {
                continue;
            }
            visited.insert(node_idx);
            result.push(node_idx);

            // Update successors
            for (succ_idx, deps) in dependencies {
                if deps.contains(&node_idx) {
                    in_degree[*succ_idx] = in_degree[*succ_idx].saturating_sub(1);
                    if in_degree[*succ_idx] == 0 {
                        queue.push_back(*succ_idx);
                    }
                }
            }
        }

        result
    }

    /// Get operation cost, defaulting to `1.0` when unset
    fn get_operation_cost(&self, node_idx: usize) -> f64 {
        self.operation_costs.get(&node_idx).copied().unwrap_or(1.0)
    }
}

impl Default for GraphScheduler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    #[test]
    fn test_execution_schedule_creation() {
        let schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.objective, SchedulingObjective::MinimizeLatency);
        assert_eq!(schedule.num_stages(), 0);
    }

    #[test]
    fn test_execution_schedule_stats() {
        let mut schedule = ExecutionSchedule::new(SchedulingObjective::MinimizeLatency);
        schedule.parallel_stages.push(vec![0, 1, 2]);
        schedule.parallel_stages.push(vec![3]);

        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.max_parallelism(), 3);
        assert_eq!(schedule.avg_parallelism(), 2.0);
    }

    #[test]
    fn test_scheduler_creation() {
        let scheduler = GraphScheduler::new();
        assert!(scheduler.operation_costs.is_empty());
    }

    #[test]
    fn test_scheduler_set_costs() {
        let mut scheduler = GraphScheduler::new();
        scheduler.set_operation_cost(0, 5.0);
        scheduler.set_tensor_memory(1, 1024);

        assert_eq!(scheduler.get_operation_cost(0), 5.0);
        assert_eq!(scheduler.tensor_memory.get(&1), Some(&1024));
    }

    #[test]
    fn test_schedule_empty_graph() {
        let scheduler = GraphScheduler::new();
        let graph = EinsumGraph::new();

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .unwrap();
        assert_eq!(schedule.num_stages(), 0);
    }

    #[test]
    fn test_schedule_single_node() {
        let mut scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        scheduler.set_operation_cost(0, 2.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .unwrap();
        assert_eq!(schedule.execution_order.len(), 1);
        assert_eq!(schedule.total_cost, 2.0);
    }
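
    #[test]
    fn test_schedule_min_latency_fan_out() {
        // A sketch of a fan-out: node 0 produces B, which nodes 1 and 2
        // both consume, so they share the second parallel stage and the
        // critical path is cost(0) + max(cost(1), cost(2)).
        let mut scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("exp", b, d))
            .unwrap();

        scheduler.set_operation_cost(0, 1.0);
        scheduler.set_operation_cost(1, 2.0);
        scheduler.set_operation_cost(2, 3.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeLatency)
            .unwrap();
        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.max_parallelism(), 2);
        assert_eq!(schedule.total_cost, 4.0);
    }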

    #[test]
    fn test_build_dependencies() {
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let deps = scheduler.build_dependencies(&graph);
        assert_eq!(deps.get(&0).unwrap().len(), 0);
        assert_eq!(deps.get(&1).unwrap(), &vec![0]);
    }
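
    #[test]
    fn test_schedule_max_throughput_parallel_stage() {
        // A sketch: two independent unary ops should land in a single
        // parallel stage under the throughput objective, costed by the
        // slower of the two.
        let mut scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", c, d))
            .unwrap();

        scheduler.set_operation_cost(0, 1.0);
        scheduler.set_operation_cost(1, 3.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MaximizeThroughput)
            .unwrap();
        assert_eq!(schedule.num_stages(), 1);
        assert_eq!(schedule.max_parallelism(), 2);
        assert_eq!(schedule.total_cost, 3.0);
    }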

    #[test]
    fn test_topological_sort() {
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let deps = scheduler.build_dependencies(&graph);
        let topo = scheduler.topological_sort(&graph, &deps);

        assert_eq!(topo.len(), 2);
        assert_eq!(topo[0], 0);
        assert_eq!(topo[1], 1);
    }
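
    #[test]
    fn test_schedule_min_memory_chain() {
        // A sketch: the memory objective runs one node per stage; a
        // two-node chain must still come out in dependency order.
        let scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::MinimizeMemory)
            .unwrap();
        assert_eq!(schedule.execution_order, vec![0, 1]);
        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.max_parallelism(), 1);
    }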

    #[test]
    fn test_scheduling_objectives() {
        assert_eq!(
            SchedulingObjective::MinimizeLatency,
            SchedulingObjective::MinimizeLatency
        );
        assert_ne!(
            SchedulingObjective::MinimizeLatency,
            SchedulingObjective::MaximizeThroughput
        );
    }
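
    #[test]
    fn test_schedule_pipeline_bottleneck_cost() {
        // A sketch: a pipeline's total cost is its bottleneck stage, not
        // the sum of all stages. With a total cost of 3.0 and a target of
        // 4 stages (0.75 per stage), each node here becomes its own stage.
        let mut scheduler = GraphScheduler::new();
        let mut graph = EinsumGraph::new();

        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        scheduler.set_operation_cost(0, 1.0);
        scheduler.set_operation_cost(1, 2.0);

        let schedule = scheduler
            .schedule(&graph, SchedulingObjective::Pipeline)
            .unwrap();
        assert_eq!(schedule.num_stages(), 2);
        assert_eq!(schedule.total_cost, 2.0);
    }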
}