// tensorlogic_infer/placement.rs

//! Device placement and multi-device execution coordination.

use std::collections::HashMap;

use tensorlogic_ir::EinsumGraph;

use crate::capabilities::DeviceType;
use crate::scheduling::ExecutionSchedule;
/// Device specification with optional id
///
/// NOTE(review): despite the "optional" wording, `device_id` is always
/// present; id 0 is the conventional default (see `Device::default_cpu`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Device {
    /// Kind of hardware (e.g. CPU or GPU); variants come from `capabilities`.
    pub device_type: DeviceType,
    /// Index distinguishing multiple devices of the same type.
    pub device_id: usize,
}
16
17impl Device {
18    pub fn new(device_type: DeviceType, device_id: usize) -> Self {
19        Device {
20            device_type,
21            device_id,
22        }
23    }
24
25    pub fn cpu(id: usize) -> Self {
26        Device::new(DeviceType::CPU, id)
27    }
28
29    pub fn gpu(id: usize) -> Self {
30        Device::new(DeviceType::GPU, id)
31    }
32
33    pub fn default_cpu() -> Self {
34        Device::cpu(0)
35    }
36
37    pub fn as_str(&self) -> String {
38        format!("{}:{}", self.device_type.as_str(), self.device_id)
39    }
40}
41
/// Placement strategy for multi-device execution
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlacementStrategy {
    /// Place all operations on a single device (the first available one)
    SingleDevice,
    /// Round-robin placement across devices, by node/tensor index
    RoundRobin,
    /// Greedy load balancing based on estimated per-operation cost
    CostBased,
    /// Local search that moves nodes to minimize cross-device data transfer
    MinimizeTransfer,
    /// Custom placement via callback
    /// NOTE(review): no callback is wired in yet; `place` falls back to
    /// single-device for this variant.
    Custom,
}
56
/// Device placement plan
///
/// Maps graph nodes and tensors to concrete devices; typically produced by
/// `PlacementOptimizer::place`.
#[derive(Debug, Clone)]
pub struct PlacementPlan {
    /// Node index -> Device mapping
    pub node_placement: HashMap<usize, Device>,
    /// Tensor index -> Device mapping
    pub tensor_placement: HashMap<usize, Device>,
    /// Estimated transfer cost (unit cost per cross-device transfer)
    pub transfer_cost: f64,
}
67
68impl PlacementPlan {
69    pub fn new() -> Self {
70        PlacementPlan {
71            node_placement: HashMap::new(),
72            tensor_placement: HashMap::new(),
73            transfer_cost: 0.0,
74        }
75    }
76
77    /// Create a single-device placement plan
78    pub fn single_device(num_nodes: usize, num_tensors: usize, device: Device) -> Self {
79        let mut plan = PlacementPlan::new();
80
81        for i in 0..num_nodes {
82            plan.node_placement.insert(i, device);
83        }
84
85        for i in 0..num_tensors {
86            plan.tensor_placement.insert(i, device);
87        }
88
89        plan
90    }
91
92    /// Get device for a node
93    pub fn get_node_device(&self, node_idx: usize) -> Option<Device> {
94        self.node_placement.get(&node_idx).copied()
95    }
96
97    /// Get device for a tensor
98    pub fn get_tensor_device(&self, tensor_idx: usize) -> Option<Device> {
99        self.tensor_placement.get(&tensor_idx).copied()
100    }
101
102    /// Count number of cross-device transfers
103    pub fn count_transfers(&self, graph: &EinsumGraph) -> usize {
104        let mut transfers = 0;
105
106        // Build a mapping from tensor index to the node that produces it
107        let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
108        for (node_idx, node) in graph.nodes.iter().enumerate() {
109            for &output_idx in &node.outputs {
110                tensor_producers.insert(output_idx, node_idx);
111            }
112        }
113
114        for (node_idx, node) in graph.nodes.iter().enumerate() {
115            let node_device = self.get_node_device(node_idx);
116
117            for &input_idx in &node.inputs {
118                // Determine the device of the input tensor
119                let input_device = if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
120                    // Tensor is produced by another node
121                    self.get_node_device(producer_idx)
122                } else {
123                    // Tensor is an input tensor
124                    self.get_tensor_device(input_idx)
125                };
126
127                if node_device != input_device {
128                    transfers += 1;
129                }
130            }
131        }
132
133        transfers
134    }
135
136    /// Get list of all devices used in this plan
137    pub fn devices(&self) -> Vec<Device> {
138        let mut devices: Vec<_> = self.node_placement.values().copied().collect();
139        devices.sort_by(|a, b| {
140            a.device_id
141                .cmp(&b.device_id)
142                .then_with(|| format!("{:?}", a.device_type).cmp(&format!("{:?}", b.device_type)))
143        });
144        devices.dedup();
145        devices
146    }
147
148    /// Summary of the placement plan
149    pub fn summary(&self) -> String {
150        let devices = self.devices();
151        format!(
152            "Placement Plan:\n\
153             - Nodes: {}\n\
154             - Tensors: {}\n\
155             - Devices: {} ({:?})\n\
156             - Transfer cost: {:.2}",
157            self.node_placement.len(),
158            self.tensor_placement.len(),
159            devices.len(),
160            devices.iter().map(|d| d.as_str()).collect::<Vec<_>>(),
161            self.transfer_cost
162        )
163    }
164}
165
166impl Default for PlacementPlan {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
/// Device placement optimizer
///
/// Assigns graph nodes and tensors to devices according to a
/// `PlacementStrategy`, drawing from a fixed pool of available devices.
pub struct PlacementOptimizer {
    // Strategy used by `place` to pick the placement algorithm.
    strategy: PlacementStrategy,
    // Device pool; order matters (index 0 is the default/fallback device).
    available_devices: Vec<Device>,
}
177
178impl PlacementOptimizer {
179    pub fn new(strategy: PlacementStrategy, available_devices: Vec<Device>) -> Self {
180        PlacementOptimizer {
181            strategy,
182            available_devices,
183        }
184    }
185
186    /// Create optimizer for single device
187    pub fn single_device(device: Device) -> Self {
188        PlacementOptimizer {
189            strategy: PlacementStrategy::SingleDevice,
190            available_devices: vec![device],
191        }
192    }
193
194    /// Compute placement plan for a graph
195    pub fn place(&self, graph: &EinsumGraph) -> PlacementPlan {
196        match self.strategy {
197            PlacementStrategy::SingleDevice => self.place_single_device(graph),
198            PlacementStrategy::RoundRobin => self.place_round_robin(graph),
199            PlacementStrategy::CostBased => self.place_cost_based(graph),
200            PlacementStrategy::MinimizeTransfer => self.place_minimize_transfer(graph),
201            PlacementStrategy::Custom => self.place_single_device(graph), // Fallback
202        }
203    }
204
205    /// Compute placement with an execution schedule
206    pub fn place_with_schedule(
207        &self,
208        graph: &EinsumGraph,
209        schedule: &ExecutionSchedule,
210    ) -> PlacementPlan {
211        let mut plan = self.place(graph);
212
213        // Use schedule's device placement if available
214        for (node_idx, device_type) in &schedule.device_placement {
215            if let Some(device) = self.find_device(*device_type) {
216                plan.node_placement.insert(*node_idx, device);
217            }
218        }
219
220        // Recompute transfer cost
221        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
222
223        plan
224    }
225
226    fn place_single_device(&self, graph: &EinsumGraph) -> PlacementPlan {
227        let device = self
228            .available_devices
229            .first()
230            .copied()
231            .unwrap_or(Device::default_cpu());
232        PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), device)
233    }
234
235    fn place_round_robin(&self, graph: &EinsumGraph) -> PlacementPlan {
236        let mut plan = PlacementPlan::new();
237
238        if self.available_devices.is_empty() {
239            return plan;
240        }
241
242        // Place input tensors
243        for (idx, _) in graph.tensors.iter().enumerate() {
244            let device = self.available_devices[idx % self.available_devices.len()];
245            plan.tensor_placement.insert(idx, device);
246        }
247
248        // Place nodes
249        for (idx, _) in graph.nodes.iter().enumerate() {
250            let device = self.available_devices[idx % self.available_devices.len()];
251            plan.node_placement.insert(idx, device);
252        }
253
254        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
255        plan
256    }
257
258    fn place_cost_based(&self, graph: &EinsumGraph) -> PlacementPlan {
259        use crate::scheduling::NodeCost;
260
261        let mut plan = PlacementPlan::new();
262
263        if self.available_devices.is_empty() {
264            return plan;
265        }
266
267        // Compute node costs
268        let costs: Vec<f64> = graph
269            .nodes
270            .iter()
271            .map(|node| NodeCost::estimate_from_node(node).total_cost())
272            .collect();
273
274        // Use greedy assignment to balance load across devices
275        let mut device_loads = vec![0.0; self.available_devices.len()];
276
277        // Sort nodes by cost (descending)
278        let mut node_order: Vec<_> = costs.iter().enumerate().collect();
279        node_order.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
280
281        // Assign each node to the least loaded device
282        for (node_idx, &cost) in node_order {
283            let min_device_idx = device_loads
284                .iter()
285                .enumerate()
286                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
287                .map(|(idx, _)| idx)
288                .unwrap_or(0);
289
290            device_loads[min_device_idx] += cost;
291            plan.node_placement
292                .insert(node_idx, self.available_devices[min_device_idx]);
293        }
294
295        // Place tensors on same device as their first consumer
296        let _num_tensors = graph.tensors.len();
297        for (tensor_idx, _) in graph.tensors.iter().enumerate() {
298            // Find first node that uses this tensor
299            let consumer_device = graph
300                .nodes
301                .iter()
302                .enumerate()
303                .find(|(_, node)| node.inputs.contains(&tensor_idx))
304                .and_then(|(node_idx, _)| plan.node_placement.get(&node_idx))
305                .copied()
306                .unwrap_or(self.available_devices[0]);
307
308            plan.tensor_placement.insert(tensor_idx, consumer_device);
309        }
310
311        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
312        plan
313    }
314
315    fn place_minimize_transfer(&self, graph: &EinsumGraph) -> PlacementPlan {
316        let mut plan = PlacementPlan::new();
317
318        if self.available_devices.is_empty() {
319            return plan;
320        }
321
322        // Start with single device placement
323        let default_device = self.available_devices[0];
324        plan = PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), default_device);
325
326        // Iteratively try to reduce transfers by moving nodes
327        let mut improved = true;
328        let max_iterations = 10;
329        let mut iteration = 0;
330
331        while improved && iteration < max_iterations {
332            improved = false;
333            iteration += 1;
334
335            let current_transfers = plan.count_transfers(graph);
336
337            for node_idx in 0..graph.nodes.len() {
338                let current_device = plan
339                    .get_node_device(node_idx)
340                    .expect("node was placed before optimization loop");
341
342                // Try each alternative device
343                for &candidate_device in &self.available_devices {
344                    if candidate_device == current_device {
345                        continue;
346                    }
347
348                    // Temporarily change placement
349                    plan.node_placement.insert(node_idx, candidate_device);
350                    let new_transfers = plan.count_transfers(graph);
351
352                    if new_transfers < current_transfers {
353                        // Keep the change
354                        improved = true;
355                        break;
356                    } else {
357                        // Revert
358                        plan.node_placement.insert(node_idx, current_device);
359                    }
360                }
361            }
362        }
363
364        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
365        plan
366    }
367
368    fn estimate_transfer_cost(&self, graph: &EinsumGraph, plan: &PlacementPlan) -> f64 {
369        // Simple cost model: 1.0 per transfer
370        plan.count_transfers(graph) as f64
371    }
372
373    fn find_device(&self, device_type: DeviceType) -> Option<Device> {
374        self.available_devices
375            .iter()
376            .find(|d| d.device_type == device_type)
377            .copied()
378    }
379
380    pub fn strategy(&self) -> PlacementStrategy {
381        self.strategy
382    }
383
384    pub fn available_devices(&self) -> &[Device] {
385        &self.available_devices
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    /// Two-node graph: einsum(x, y) -> t2, relu(t2) -> t3.
    fn create_test_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        // Tensors 0/1 are inputs; 2 is produced by node 0, 3 by node 1.
        g.tensors.extend(["x", "y", "t2", "t3"].map(String::from));

        g.nodes.push(EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        });

        g.nodes.push(EinsumNode {
            inputs: vec![2],
            outputs: vec![3],
            op: OpType::ElemUnary { op: "relu".into() },
            metadata: None,
        });

        g
    }

    #[test]
    fn test_device_creation() {
        let c = Device::cpu(0);
        assert_eq!(c.device_type, DeviceType::CPU);
        assert_eq!(c.device_id, 0);
        assert_eq!(c.as_str(), "CPU:0");

        let g = Device::gpu(1);
        assert_eq!(g.device_type, DeviceType::GPU);
        assert_eq!(g.device_id, 1);
        assert_eq!(g.as_str(), "GPU:1");
    }

    #[test]
    fn test_placement_plan_single_device() {
        let dev = Device::cpu(0);
        let plan = PlacementPlan::single_device(3, 2, dev);

        assert_eq!(plan.node_placement.len(), 3);
        assert_eq!(plan.tensor_placement.len(), 2);
        assert_eq!(plan.get_node_device(0), Some(dev));
        assert_eq!(plan.get_tensor_device(0), Some(dev));
    }

    #[test]
    fn test_placement_plan_devices() {
        let mut plan = PlacementPlan::new();
        for (idx, dev) in [Device::cpu(0), Device::gpu(0), Device::gpu(1)]
            .into_iter()
            .enumerate()
        {
            plan.node_placement.insert(idx, dev);
        }

        // At least CPU and GPU should be reported.
        assert!(plan.devices().len() >= 2);
    }

    #[test]
    fn test_single_device_placement() {
        let g = create_test_graph();
        let opt = PlacementOptimizer::single_device(Device::cpu(0));
        let plan = opt.place(&g);

        assert_eq!(plan.node_placement.len(), 2);
        // A single device never transfers.
        assert_eq!(plan.count_transfers(&g), 0);
    }

    #[test]
    fn test_round_robin_placement() {
        let g = create_test_graph();
        let opt =
            PlacementOptimizer::new(PlacementStrategy::RoundRobin, vec![Device::cpu(0), Device::cpu(1)]);
        let plan = opt.place(&g);

        assert_eq!(plan.node_placement.len(), 2);
        // Consecutive nodes land on different devices.
        assert_ne!(plan.get_node_device(0), plan.get_node_device(1));
    }

    #[test]
    fn test_cost_based_placement() {
        let g = create_test_graph();
        let opt =
            PlacementOptimizer::new(PlacementStrategy::CostBased, vec![Device::cpu(0), Device::gpu(0)]);
        let plan = opt.place(&g);

        assert_eq!(plan.node_placement.len(), 2);
        assert!(plan.transfer_cost >= 0.0);
    }

    #[test]
    fn test_minimize_transfer_placement() {
        let g = create_test_graph();
        let opt = PlacementOptimizer::new(
            PlacementStrategy::MinimizeTransfer,
            vec![Device::cpu(0), Device::cpu(1)],
        );
        let plan = opt.place(&g);

        assert_eq!(plan.node_placement.len(), 2);
        // Should stay close to the single-device (zero-transfer) baseline.
        let baseline = PlacementOptimizer::single_device(Device::cpu(0)).place(&g);
        assert!(plan.count_transfers(&g) <= baseline.count_transfers(&g) + 2);
    }

    #[test]
    fn test_transfer_counting() {
        let g = create_test_graph();

        // Single device: no transfers at all.
        let plan1 = PlacementPlan::single_device(2, 4, Device::cpu(0));
        assert_eq!(plan1.count_transfers(&g), 0);

        // Split across CPU and GPU: at least one transfer must appear.
        let mut plan2 = PlacementPlan::new();
        plan2.node_placement.insert(0, Device::cpu(0));
        plan2.node_placement.insert(1, Device::gpu(0));
        plan2.tensor_placement.insert(0, Device::cpu(0));
        plan2.tensor_placement.insert(1, Device::cpu(0));
        plan2.tensor_placement.insert(2, Device::cpu(0)); // output of node 0
        plan2.tensor_placement.insert(3, Device::gpu(0)); // output of node 1

        assert!(plan2.count_transfers(&g) > 0);
    }

    #[test]
    fn test_placement_strategies() {
        let g = create_test_graph();
        let devs = vec![Device::cpu(0), Device::cpu(1)];

        // Every non-custom strategy must produce a non-empty placement.
        for strategy in [
            PlacementStrategy::SingleDevice,
            PlacementStrategy::RoundRobin,
            PlacementStrategy::CostBased,
            PlacementStrategy::MinimizeTransfer,
        ] {
            let opt = PlacementOptimizer::new(strategy, devs.clone());
            let plan = opt.place(&g);
            assert!(
                !plan.node_placement.is_empty(),
                "Strategy {:?} failed",
                strategy
            );
        }
    }

    #[test]
    fn test_placement_summary() {
        let plan = PlacementPlan::single_device(5, 3, Device::cpu(0));
        let summary = plan.summary();

        assert!(summary.contains("Placement Plan"));
        assert!(summary.contains("Nodes: 5"));
        assert!(summary.contains("Tensors: 3"));
    }
}