1use std::collections::HashMap;
4
5use tensorlogic_ir::EinsumGraph;
6
7use crate::capabilities::DeviceType;
8use crate::scheduling::ExecutionSchedule;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
12pub struct Device {
13 pub device_type: DeviceType,
14 pub device_id: usize,
15}
16
17impl Device {
18 pub fn new(device_type: DeviceType, device_id: usize) -> Self {
19 Device {
20 device_type,
21 device_id,
22 }
23 }
24
25 pub fn cpu(id: usize) -> Self {
26 Device::new(DeviceType::CPU, id)
27 }
28
29 pub fn gpu(id: usize) -> Self {
30 Device::new(DeviceType::GPU, id)
31 }
32
33 pub fn default_cpu() -> Self {
34 Device::cpu(0)
35 }
36
37 pub fn as_str(&self) -> String {
38 format!("{}:{}", self.device_type.as_str(), self.device_id)
39 }
40}
41
/// Policy used by `PlacementOptimizer` to assign nodes and tensors to devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlacementStrategy {
    /// Put every node and tensor on the first available device
    /// (CPU 0 when no device is configured).
    SingleDevice,
    /// Cycle through the available devices by node and tensor index.
    RoundRobin,
    /// Assign nodes in order of decreasing estimated cost, each to the
    /// device with the smallest accumulated load.
    CostBased,
    /// Start from a single-device plan and move individual nodes while
    /// doing so lowers the number of cross-device transfers.
    MinimizeTransfer,
    /// Reserved for caller-defined placement; `PlacementOptimizer::place`
    /// currently handles it the same way as `SingleDevice`.
    Custom,
}
56
57#[derive(Debug, Clone)]
59pub struct PlacementPlan {
60 pub node_placement: HashMap<usize, Device>,
62 pub tensor_placement: HashMap<usize, Device>,
64 pub transfer_cost: f64,
66}
67
68impl PlacementPlan {
69 pub fn new() -> Self {
70 PlacementPlan {
71 node_placement: HashMap::new(),
72 tensor_placement: HashMap::new(),
73 transfer_cost: 0.0,
74 }
75 }
76
77 pub fn single_device(num_nodes: usize, num_tensors: usize, device: Device) -> Self {
79 let mut plan = PlacementPlan::new();
80
81 for i in 0..num_nodes {
82 plan.node_placement.insert(i, device);
83 }
84
85 for i in 0..num_tensors {
86 plan.tensor_placement.insert(i, device);
87 }
88
89 plan
90 }
91
92 pub fn get_node_device(&self, node_idx: usize) -> Option<Device> {
94 self.node_placement.get(&node_idx).copied()
95 }
96
97 pub fn get_tensor_device(&self, tensor_idx: usize) -> Option<Device> {
99 self.tensor_placement.get(&tensor_idx).copied()
100 }
101
102 pub fn count_transfers(&self, graph: &EinsumGraph) -> usize {
104 let mut transfers = 0;
105
106 let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
108 for (node_idx, node) in graph.nodes.iter().enumerate() {
109 for &output_idx in &node.outputs {
110 tensor_producers.insert(output_idx, node_idx);
111 }
112 }
113
114 for (node_idx, node) in graph.nodes.iter().enumerate() {
115 let node_device = self.get_node_device(node_idx);
116
117 for &input_idx in &node.inputs {
118 let input_device = if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
120 self.get_node_device(producer_idx)
122 } else {
123 self.get_tensor_device(input_idx)
125 };
126
127 if node_device != input_device {
128 transfers += 1;
129 }
130 }
131 }
132
133 transfers
134 }
135
136 pub fn devices(&self) -> Vec<Device> {
138 let mut devices: Vec<_> = self.node_placement.values().copied().collect();
139 devices.sort_by(|a, b| {
140 a.device_id
141 .cmp(&b.device_id)
142 .then_with(|| format!("{:?}", a.device_type).cmp(&format!("{:?}", b.device_type)))
143 });
144 devices.dedup();
145 devices
146 }
147
148 pub fn summary(&self) -> String {
150 let devices = self.devices();
151 format!(
152 "Placement Plan:\n\
153 - Nodes: {}\n\
154 - Tensors: {}\n\
155 - Devices: {} ({:?})\n\
156 - Transfer cost: {:.2}",
157 self.node_placement.len(),
158 self.tensor_placement.len(),
159 devices.len(),
160 devices.iter().map(|d| d.as_str()).collect::<Vec<_>>(),
161 self.transfer_cost
162 )
163 }
164}
165
166impl Default for PlacementPlan {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172pub struct PlacementOptimizer {
174 strategy: PlacementStrategy,
175 available_devices: Vec<Device>,
176}
177
178impl PlacementOptimizer {
179 pub fn new(strategy: PlacementStrategy, available_devices: Vec<Device>) -> Self {
180 PlacementOptimizer {
181 strategy,
182 available_devices,
183 }
184 }
185
186 pub fn single_device(device: Device) -> Self {
188 PlacementOptimizer {
189 strategy: PlacementStrategy::SingleDevice,
190 available_devices: vec![device],
191 }
192 }
193
194 pub fn place(&self, graph: &EinsumGraph) -> PlacementPlan {
196 match self.strategy {
197 PlacementStrategy::SingleDevice => self.place_single_device(graph),
198 PlacementStrategy::RoundRobin => self.place_round_robin(graph),
199 PlacementStrategy::CostBased => self.place_cost_based(graph),
200 PlacementStrategy::MinimizeTransfer => self.place_minimize_transfer(graph),
201 PlacementStrategy::Custom => self.place_single_device(graph), }
203 }
204
205 pub fn place_with_schedule(
207 &self,
208 graph: &EinsumGraph,
209 schedule: &ExecutionSchedule,
210 ) -> PlacementPlan {
211 let mut plan = self.place(graph);
212
213 for (node_idx, device_type) in &schedule.device_placement {
215 if let Some(device) = self.find_device(*device_type) {
216 plan.node_placement.insert(*node_idx, device);
217 }
218 }
219
220 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
222
223 plan
224 }
225
226 fn place_single_device(&self, graph: &EinsumGraph) -> PlacementPlan {
227 let device = self
228 .available_devices
229 .first()
230 .copied()
231 .unwrap_or(Device::default_cpu());
232 PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), device)
233 }
234
235 fn place_round_robin(&self, graph: &EinsumGraph) -> PlacementPlan {
236 let mut plan = PlacementPlan::new();
237
238 if self.available_devices.is_empty() {
239 return plan;
240 }
241
242 for (idx, _) in graph.tensors.iter().enumerate() {
244 let device = self.available_devices[idx % self.available_devices.len()];
245 plan.tensor_placement.insert(idx, device);
246 }
247
248 for (idx, _) in graph.nodes.iter().enumerate() {
250 let device = self.available_devices[idx % self.available_devices.len()];
251 plan.node_placement.insert(idx, device);
252 }
253
254 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
255 plan
256 }
257
258 fn place_cost_based(&self, graph: &EinsumGraph) -> PlacementPlan {
259 use crate::scheduling::NodeCost;
260
261 let mut plan = PlacementPlan::new();
262
263 if self.available_devices.is_empty() {
264 return plan;
265 }
266
267 let costs: Vec<f64> = graph
269 .nodes
270 .iter()
271 .map(|node| NodeCost::estimate_from_node(node).total_cost())
272 .collect();
273
274 let mut device_loads = vec![0.0; self.available_devices.len()];
276
277 let mut node_order: Vec<_> = costs.iter().enumerate().collect();
279 node_order.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
280
281 for (node_idx, &cost) in node_order {
283 let min_device_idx = device_loads
284 .iter()
285 .enumerate()
286 .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
287 .map(|(idx, _)| idx)
288 .unwrap_or(0);
289
290 device_loads[min_device_idx] += cost;
291 plan.node_placement
292 .insert(node_idx, self.available_devices[min_device_idx]);
293 }
294
295 let _num_tensors = graph.tensors.len();
297 for (tensor_idx, _) in graph.tensors.iter().enumerate() {
298 let consumer_device = graph
300 .nodes
301 .iter()
302 .enumerate()
303 .find(|(_, node)| node.inputs.contains(&tensor_idx))
304 .and_then(|(node_idx, _)| plan.node_placement.get(&node_idx))
305 .copied()
306 .unwrap_or(self.available_devices[0]);
307
308 plan.tensor_placement.insert(tensor_idx, consumer_device);
309 }
310
311 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
312 plan
313 }
314
315 fn place_minimize_transfer(&self, graph: &EinsumGraph) -> PlacementPlan {
316 let mut plan = PlacementPlan::new();
317
318 if self.available_devices.is_empty() {
319 return plan;
320 }
321
322 let default_device = self.available_devices[0];
324 plan = PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), default_device);
325
326 let mut improved = true;
328 let max_iterations = 10;
329 let mut iteration = 0;
330
331 while improved && iteration < max_iterations {
332 improved = false;
333 iteration += 1;
334
335 let current_transfers = plan.count_transfers(graph);
336
337 for node_idx in 0..graph.nodes.len() {
338 let current_device = plan.get_node_device(node_idx).unwrap();
339
340 for &candidate_device in &self.available_devices {
342 if candidate_device == current_device {
343 continue;
344 }
345
346 plan.node_placement.insert(node_idx, candidate_device);
348 let new_transfers = plan.count_transfers(graph);
349
350 if new_transfers < current_transfers {
351 improved = true;
353 break;
354 } else {
355 plan.node_placement.insert(node_idx, current_device);
357 }
358 }
359 }
360 }
361
362 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
363 plan
364 }
365
366 fn estimate_transfer_cost(&self, graph: &EinsumGraph, plan: &PlacementPlan) -> f64 {
367 plan.count_transfers(graph) as f64
369 }
370
371 fn find_device(&self, device_type: DeviceType) -> Option<Device> {
372 self.available_devices
373 .iter()
374 .find(|d| d.device_type == device_type)
375 .copied()
376 }
377
378 pub fn strategy(&self) -> PlacementStrategy {
379 self.strategy
380 }
381
382 pub fn available_devices(&self) -> &[Device] {
383 &self.available_devices
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    /// Two-node chain: node 0 reads tensors 0 and 1 and writes tensor 2;
    /// node 1 reads tensor 2 and writes tensor 3.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.tensors.push("x".to_string());
        graph.tensors.push("y".to_string());
        graph.tensors.push("t2".to_string());
        graph.tensors.push("t3".to_string());

        graph.nodes.push(EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        });

        graph.nodes.push(EinsumNode {
            inputs: vec![2],
            outputs: vec![3],
            op: OpType::ElemUnary { op: "relu".into() },
            metadata: None,
        });

        graph
    }

    #[test]
    fn test_device_creation() {
        let cpu = Device::cpu(0);
        assert_eq!(cpu.device_type, DeviceType::CPU);
        assert_eq!(cpu.device_id, 0);
        assert_eq!(cpu.as_str(), "CPU:0");

        let gpu = Device::gpu(1);
        assert_eq!(gpu.device_type, DeviceType::GPU);
        assert_eq!(gpu.device_id, 1);
        assert_eq!(gpu.as_str(), "GPU:1");
    }

    #[test]
    fn test_placement_plan_single_device() {
        let device = Device::cpu(0);
        let plan = PlacementPlan::single_device(3, 2, device);

        assert_eq!(plan.node_placement.len(), 3);
        assert_eq!(plan.tensor_placement.len(), 2);
        assert_eq!(plan.get_node_device(0), Some(device));
        assert_eq!(plan.get_tensor_device(0), Some(device));
    }

    #[test]
    fn test_placement_plan_devices() {
        let mut plan = PlacementPlan::new();
        plan.node_placement.insert(0, Device::cpu(0));
        plan.node_placement.insert(1, Device::gpu(0));
        plan.node_placement.insert(2, Device::gpu(1));

        // Three distinct devices were inserted, so exactly three come back.
        let devices = plan.devices();
        assert_eq!(devices.len(), 3);
        assert!(devices.contains(&Device::cpu(0)));
        assert!(devices.contains(&Device::gpu(0)));
        assert!(devices.contains(&Device::gpu(1)));
    }

    #[test]
    fn test_single_device_placement() {
        let graph = create_test_graph();
        let optimizer = PlacementOptimizer::single_device(Device::cpu(0));
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        // Everything shares one device, so nothing has to move.
        assert_eq!(plan.count_transfers(&graph), 0);
    }

    #[test]
    fn test_round_robin_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::RoundRobin, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        // Two nodes over two devices must land on different devices.
        let dev0 = plan.get_node_device(0);
        let dev1 = plan.get_node_device(1);
        assert_ne!(dev0, dev1);
    }

    #[test]
    fn test_cost_based_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::gpu(0)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::CostBased, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        assert!(plan.transfer_cost >= 0.0);
    }

    #[test]
    fn test_minimize_transfer_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::MinimizeTransfer, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        let single_device_plan = PlacementOptimizer::single_device(Device::cpu(0)).place(&graph);
        assert!(plan.count_transfers(&graph) <= single_device_plan.count_transfers(&graph) + 2);
    }

    #[test]
    fn test_transfer_counting() {
        let graph = create_test_graph();

        // Same device everywhere: no transfers.
        let plan1 = PlacementPlan::single_device(2, 4, Device::cpu(0));
        assert_eq!(plan1.count_transfers(&graph), 0);

        // Node 1 runs on the GPU but reads tensor 2, which node 0 produces
        // on the CPU. That edge is the only cross-device one.
        let mut plan2 = PlacementPlan::new();
        plan2.node_placement.insert(0, Device::cpu(0));
        plan2.node_placement.insert(1, Device::gpu(0));
        plan2.tensor_placement.insert(0, Device::cpu(0));
        plan2.tensor_placement.insert(1, Device::cpu(0));
        plan2.tensor_placement.insert(2, Device::cpu(0));
        plan2.tensor_placement.insert(3, Device::gpu(0));

        assert_eq!(plan2.count_transfers(&graph), 1);
    }

    #[test]
    fn test_placement_strategies() {
        let strategies = [
            PlacementStrategy::SingleDevice,
            PlacementStrategy::RoundRobin,
            PlacementStrategy::CostBased,
            PlacementStrategy::MinimizeTransfer,
            PlacementStrategy::Custom,
        ];

        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];

        for &strategy in &strategies {
            let optimizer = PlacementOptimizer::new(strategy, devices.clone());
            let plan = optimizer.place(&graph);
            assert!(
                !plan.node_placement.is_empty(),
                "Strategy {:?} failed",
                strategy
            );
        }
    }

    #[test]
    fn test_placement_summary() {
        let plan = PlacementPlan::single_device(5, 3, Device::cpu(0));
        let summary = plan.summary();

        assert!(summary.contains("Placement Plan"));
        assert!(summary.contains("Nodes: 5"));
        assert!(summary.contains("Tensors: 3"));
    }
}