// tensorlogic_infer/placement.rs
//! Device placement and multi-device execution coordination.

use std::collections::HashMap;

use tensorlogic_ir::EinsumGraph;

use crate::capabilities::DeviceType;
use crate::scheduling::ExecutionSchedule;
/// Device specification with optional id
///
/// A concrete execution device: a hardware class plus an ordinal id
/// (e.g. CPU:0, GPU:1). Derives `Copy`, `Eq` and `Hash` so it can be used
/// directly as a `HashMap` key in placement plans.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Device {
    /// Hardware class (CPU, GPU, ...).
    pub device_type: DeviceType,
    /// Ordinal id distinguishing multiple devices of the same type.
    pub device_id: usize,
}

17impl Device {
18    pub fn new(device_type: DeviceType, device_id: usize) -> Self {
19        Device {
20            device_type,
21            device_id,
22        }
23    }
24
25    pub fn cpu(id: usize) -> Self {
26        Device::new(DeviceType::CPU, id)
27    }
28
29    pub fn gpu(id: usize) -> Self {
30        Device::new(DeviceType::GPU, id)
31    }
32
33    pub fn default_cpu() -> Self {
34        Device::cpu(0)
35    }
36
37    pub fn as_str(&self) -> String {
38        format!("{}:{}", self.device_type.as_str(), self.device_id)
39    }
40}
41
/// Placement strategy for multi-device execution
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlacementStrategy {
    /// Place all operations on a single device
    SingleDevice,
    /// Round-robin placement across devices
    RoundRobin,
    /// Place based on operation cost
    CostBased,
    /// Place to minimize data transfer
    MinimizeTransfer,
    /// Custom placement via callback
    /// (NOTE(review): `PlacementOptimizer::place` currently falls back to
    /// single-device placement for this variant — no callback hook exists yet).
    Custom,
}

/// Device placement plan
///
/// Maps graph node indices and tensor indices to concrete devices, plus an
/// estimated cost of the data transfers this placement implies.
#[derive(Debug, Clone)]
pub struct PlacementPlan {
    /// Node index -> Device mapping
    pub node_placement: HashMap<usize, Device>,
    /// Tensor index -> Device mapping
    pub tensor_placement: HashMap<usize, Device>,
    /// Estimated transfer cost
    /// (1.0 per cross-device edge under the current cost model).
    pub transfer_cost: f64,
}

68impl PlacementPlan {
69    pub fn new() -> Self {
70        PlacementPlan {
71            node_placement: HashMap::new(),
72            tensor_placement: HashMap::new(),
73            transfer_cost: 0.0,
74        }
75    }
76
77    /// Create a single-device placement plan
78    pub fn single_device(num_nodes: usize, num_tensors: usize, device: Device) -> Self {
79        let mut plan = PlacementPlan::new();
80
81        for i in 0..num_nodes {
82            plan.node_placement.insert(i, device);
83        }
84
85        for i in 0..num_tensors {
86            plan.tensor_placement.insert(i, device);
87        }
88
89        plan
90    }
91
92    /// Get device for a node
93    pub fn get_node_device(&self, node_idx: usize) -> Option<Device> {
94        self.node_placement.get(&node_idx).copied()
95    }
96
97    /// Get device for a tensor
98    pub fn get_tensor_device(&self, tensor_idx: usize) -> Option<Device> {
99        self.tensor_placement.get(&tensor_idx).copied()
100    }
101
102    /// Count number of cross-device transfers
103    pub fn count_transfers(&self, graph: &EinsumGraph) -> usize {
104        let mut transfers = 0;
105
106        // Build a mapping from tensor index to the node that produces it
107        let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
108        for (node_idx, node) in graph.nodes.iter().enumerate() {
109            for &output_idx in &node.outputs {
110                tensor_producers.insert(output_idx, node_idx);
111            }
112        }
113
114        for (node_idx, node) in graph.nodes.iter().enumerate() {
115            let node_device = self.get_node_device(node_idx);
116
117            for &input_idx in &node.inputs {
118                // Determine the device of the input tensor
119                let input_device = if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
120                    // Tensor is produced by another node
121                    self.get_node_device(producer_idx)
122                } else {
123                    // Tensor is an input tensor
124                    self.get_tensor_device(input_idx)
125                };
126
127                if node_device != input_device {
128                    transfers += 1;
129                }
130            }
131        }
132
133        transfers
134    }
135
136    /// Get list of all devices used in this plan
137    pub fn devices(&self) -> Vec<Device> {
138        let mut devices: Vec<_> = self.node_placement.values().copied().collect();
139        devices.sort_by(|a, b| {
140            a.device_id
141                .cmp(&b.device_id)
142                .then_with(|| format!("{:?}", a.device_type).cmp(&format!("{:?}", b.device_type)))
143        });
144        devices.dedup();
145        devices
146    }
147
148    /// Summary of the placement plan
149    pub fn summary(&self) -> String {
150        let devices = self.devices();
151        format!(
152            "Placement Plan:\n\
153             - Nodes: {}\n\
154             - Tensors: {}\n\
155             - Devices: {} ({:?})\n\
156             - Transfer cost: {:.2}",
157            self.node_placement.len(),
158            self.tensor_placement.len(),
159            devices.len(),
160            devices.iter().map(|d| d.as_str()).collect::<Vec<_>>(),
161            self.transfer_cost
162        )
163    }
164}
165
166impl Default for PlacementPlan {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
/// Device placement optimizer
///
/// Applies a [`PlacementStrategy`] over a fixed pool of available devices to
/// produce a [`PlacementPlan`] for a graph.
pub struct PlacementOptimizer {
    // Strategy chosen at construction; drives dispatch in `place()`.
    strategy: PlacementStrategy,
    // Candidate devices; index order matters for round-robin placement and
    // for the "first device" fallbacks.
    available_devices: Vec<Device>,
}

178impl PlacementOptimizer {
179    pub fn new(strategy: PlacementStrategy, available_devices: Vec<Device>) -> Self {
180        PlacementOptimizer {
181            strategy,
182            available_devices,
183        }
184    }
185
186    /// Create optimizer for single device
187    pub fn single_device(device: Device) -> Self {
188        PlacementOptimizer {
189            strategy: PlacementStrategy::SingleDevice,
190            available_devices: vec![device],
191        }
192    }
193
194    /// Compute placement plan for a graph
195    pub fn place(&self, graph: &EinsumGraph) -> PlacementPlan {
196        match self.strategy {
197            PlacementStrategy::SingleDevice => self.place_single_device(graph),
198            PlacementStrategy::RoundRobin => self.place_round_robin(graph),
199            PlacementStrategy::CostBased => self.place_cost_based(graph),
200            PlacementStrategy::MinimizeTransfer => self.place_minimize_transfer(graph),
201            PlacementStrategy::Custom => self.place_single_device(graph), // Fallback
202        }
203    }
204
205    /// Compute placement with an execution schedule
206    pub fn place_with_schedule(
207        &self,
208        graph: &EinsumGraph,
209        schedule: &ExecutionSchedule,
210    ) -> PlacementPlan {
211        let mut plan = self.place(graph);
212
213        // Use schedule's device placement if available
214        for (node_idx, device_type) in &schedule.device_placement {
215            if let Some(device) = self.find_device(*device_type) {
216                plan.node_placement.insert(*node_idx, device);
217            }
218        }
219
220        // Recompute transfer cost
221        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
222
223        plan
224    }
225
226    fn place_single_device(&self, graph: &EinsumGraph) -> PlacementPlan {
227        let device = self
228            .available_devices
229            .first()
230            .copied()
231            .unwrap_or(Device::default_cpu());
232        PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), device)
233    }
234
235    fn place_round_robin(&self, graph: &EinsumGraph) -> PlacementPlan {
236        let mut plan = PlacementPlan::new();
237
238        if self.available_devices.is_empty() {
239            return plan;
240        }
241
242        // Place input tensors
243        for (idx, _) in graph.tensors.iter().enumerate() {
244            let device = self.available_devices[idx % self.available_devices.len()];
245            plan.tensor_placement.insert(idx, device);
246        }
247
248        // Place nodes
249        for (idx, _) in graph.nodes.iter().enumerate() {
250            let device = self.available_devices[idx % self.available_devices.len()];
251            plan.node_placement.insert(idx, device);
252        }
253
254        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
255        plan
256    }
257
258    fn place_cost_based(&self, graph: &EinsumGraph) -> PlacementPlan {
259        use crate::scheduling::NodeCost;
260
261        let mut plan = PlacementPlan::new();
262
263        if self.available_devices.is_empty() {
264            return plan;
265        }
266
267        // Compute node costs
268        let costs: Vec<f64> = graph
269            .nodes
270            .iter()
271            .map(|node| NodeCost::estimate_from_node(node).total_cost())
272            .collect();
273
274        // Use greedy assignment to balance load across devices
275        let mut device_loads = vec![0.0; self.available_devices.len()];
276
277        // Sort nodes by cost (descending)
278        let mut node_order: Vec<_> = costs.iter().enumerate().collect();
279        node_order.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
280
281        // Assign each node to the least loaded device
282        for (node_idx, &cost) in node_order {
283            let min_device_idx = device_loads
284                .iter()
285                .enumerate()
286                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
287                .map(|(idx, _)| idx)
288                .unwrap_or(0);
289
290            device_loads[min_device_idx] += cost;
291            plan.node_placement
292                .insert(node_idx, self.available_devices[min_device_idx]);
293        }
294
295        // Place tensors on same device as their first consumer
296        let _num_tensors = graph.tensors.len();
297        for (tensor_idx, _) in graph.tensors.iter().enumerate() {
298            // Find first node that uses this tensor
299            let consumer_device = graph
300                .nodes
301                .iter()
302                .enumerate()
303                .find(|(_, node)| node.inputs.contains(&tensor_idx))
304                .and_then(|(node_idx, _)| plan.node_placement.get(&node_idx))
305                .copied()
306                .unwrap_or(self.available_devices[0]);
307
308            plan.tensor_placement.insert(tensor_idx, consumer_device);
309        }
310
311        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
312        plan
313    }
314
315    fn place_minimize_transfer(&self, graph: &EinsumGraph) -> PlacementPlan {
316        let mut plan = PlacementPlan::new();
317
318        if self.available_devices.is_empty() {
319            return plan;
320        }
321
322        // Start with single device placement
323        let default_device = self.available_devices[0];
324        plan = PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), default_device);
325
326        // Iteratively try to reduce transfers by moving nodes
327        let mut improved = true;
328        let max_iterations = 10;
329        let mut iteration = 0;
330
331        while improved && iteration < max_iterations {
332            improved = false;
333            iteration += 1;
334
335            let current_transfers = plan.count_transfers(graph);
336
337            for node_idx in 0..graph.nodes.len() {
338                let current_device = plan.get_node_device(node_idx).unwrap();
339
340                // Try each alternative device
341                for &candidate_device in &self.available_devices {
342                    if candidate_device == current_device {
343                        continue;
344                    }
345
346                    // Temporarily change placement
347                    plan.node_placement.insert(node_idx, candidate_device);
348                    let new_transfers = plan.count_transfers(graph);
349
350                    if new_transfers < current_transfers {
351                        // Keep the change
352                        improved = true;
353                        break;
354                    } else {
355                        // Revert
356                        plan.node_placement.insert(node_idx, current_device);
357                    }
358                }
359            }
360        }
361
362        plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
363        plan
364    }
365
366    fn estimate_transfer_cost(&self, graph: &EinsumGraph, plan: &PlacementPlan) -> f64 {
367        // Simple cost model: 1.0 per transfer
368        plan.count_transfers(graph) as f64
369    }
370
371    fn find_device(&self, device_type: DeviceType) -> Option<Device> {
372        self.available_devices
373            .iter()
374            .find(|d| d.device_type == device_type)
375            .copied()
376    }
377
378    pub fn strategy(&self) -> PlacementStrategy {
379        self.strategy
380    }
381
382    pub fn available_devices(&self) -> &[Device] {
383        &self.available_devices
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    /// Two-node chain: einsum(x, y) -> t2, relu(t2) -> t3.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.tensors.push("x".to_string());
        graph.tensors.push("y".to_string());
        graph.tensors.push("t2".to_string()); // Output of node 0
        graph.tensors.push("t3".to_string()); // Output of node 1

        graph.nodes.push(EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        });

        graph.nodes.push(EinsumNode {
            inputs: vec![2],
            outputs: vec![3],
            op: OpType::ElemUnary { op: "relu".into() },
            metadata: None,
        });

        graph
    }

    #[test]
    fn test_device_creation() {
        let cpu = Device::cpu(0);
        assert_eq!(cpu.device_type, DeviceType::CPU);
        assert_eq!(cpu.device_id, 0);
        assert_eq!(cpu.as_str(), "CPU:0");

        let gpu = Device::gpu(1);
        assert_eq!(gpu.device_type, DeviceType::GPU);
        assert_eq!(gpu.device_id, 1);
        assert_eq!(gpu.as_str(), "GPU:1");
    }

    #[test]
    fn test_placement_plan_single_device() {
        let device = Device::cpu(0);
        let plan = PlacementPlan::single_device(3, 2, device);

        assert_eq!(plan.node_placement.len(), 3);
        assert_eq!(plan.tensor_placement.len(), 2);
        assert_eq!(plan.get_node_device(0), Some(device));
        assert_eq!(plan.get_tensor_device(0), Some(device));
    }

    #[test]
    fn test_placement_plan_devices() {
        let mut plan = PlacementPlan::new();
        plan.node_placement.insert(0, Device::cpu(0));
        plan.node_placement.insert(1, Device::gpu(0));
        plan.node_placement.insert(2, Device::gpu(1));

        // devices() dedups, so three placements over three distinct devices
        // must yield at least the two device types.
        let devices = plan.devices();
        assert!(devices.len() >= 2); // At least CPU and GPU
    }

    #[test]
    fn test_single_device_placement() {
        let graph = create_test_graph();
        let optimizer = PlacementOptimizer::single_device(Device::cpu(0));
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        assert_eq!(plan.count_transfers(&graph), 0); // No transfers on single device
    }

    #[test]
    fn test_round_robin_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::RoundRobin, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        // Different nodes should be on different devices (2 nodes, 2 devices,
        // index modulo pool size).
        let dev0 = plan.get_node_device(0);
        let dev1 = plan.get_node_device(1);
        assert_ne!(dev0, dev1);
    }

    #[test]
    fn test_cost_based_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::gpu(0)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::CostBased, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        assert!(plan.transfer_cost >= 0.0);
    }

    #[test]
    fn test_minimize_transfer_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::MinimizeTransfer, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        // Should minimize transfers: never much worse than all-on-one-device.
        let single_device_plan = PlacementOptimizer::single_device(Device::cpu(0)).place(&graph);
        assert!(plan.count_transfers(&graph) <= single_device_plan.count_transfers(&graph) + 2);
    }

    #[test]
    fn test_transfer_counting() {
        let graph = create_test_graph();

        // Single device: no transfers
        let plan1 = PlacementPlan::single_device(2, 4, Device::cpu(0));
        assert_eq!(plan1.count_transfers(&graph), 0);

        // Different devices: some transfers (node 1 on GPU consumes tensor 2
        // produced by node 0 on CPU).
        let mut plan2 = PlacementPlan::new();
        plan2.node_placement.insert(0, Device::cpu(0));
        plan2.node_placement.insert(1, Device::gpu(0));
        plan2.tensor_placement.insert(0, Device::cpu(0));
        plan2.tensor_placement.insert(1, Device::cpu(0));
        plan2.tensor_placement.insert(2, Device::cpu(0)); // Output of node 0
        plan2.tensor_placement.insert(3, Device::gpu(0)); // Output of node 1

        let transfers = plan2.count_transfers(&graph);
        assert!(transfers > 0); // Should have at least one transfer
    }

    #[test]
    fn test_placement_strategies() {
        // Every non-Custom strategy must place all nodes of the test graph.
        let strategies = vec![
            PlacementStrategy::SingleDevice,
            PlacementStrategy::RoundRobin,
            PlacementStrategy::CostBased,
            PlacementStrategy::MinimizeTransfer,
        ];

        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];

        for strategy in strategies {
            let optimizer = PlacementOptimizer::new(strategy, devices.clone());
            let plan = optimizer.place(&graph);
            assert!(
                !plan.node_placement.is_empty(),
                "Strategy {:?} failed",
                strategy
            );
        }
    }

    #[test]
    fn test_placement_summary() {
        let plan = PlacementPlan::single_device(5, 3, Device::cpu(0));
        let summary = plan.summary();

        assert!(summary.contains("Placement Plan"));
        assert!(summary.contains("Nodes: 5"));
        assert!(summary.contains("Tensors: 3"));
    }
}