1use std::collections::HashMap;
4
5use tensorlogic_ir::EinsumGraph;
6
7use crate::capabilities::DeviceType;
8use crate::scheduling::ExecutionSchedule;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
12pub struct Device {
13 pub device_type: DeviceType,
14 pub device_id: usize,
15}
16
17impl Device {
18 pub fn new(device_type: DeviceType, device_id: usize) -> Self {
19 Device {
20 device_type,
21 device_id,
22 }
23 }
24
25 pub fn cpu(id: usize) -> Self {
26 Device::new(DeviceType::CPU, id)
27 }
28
29 pub fn gpu(id: usize) -> Self {
30 Device::new(DeviceType::GPU, id)
31 }
32
33 pub fn default_cpu() -> Self {
34 Device::cpu(0)
35 }
36
37 pub fn as_str(&self) -> String {
38 format!("{}:{}", self.device_type.as_str(), self.device_id)
39 }
40}
41
/// Policy used by `PlacementOptimizer::place` to assign graph nodes and
/// tensors to devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlacementStrategy {
    /// Put every node and tensor on the first available device.
    SingleDevice,
    /// Assign nodes and tensors to the available devices in rotation, by index.
    RoundRobin,
    /// Balance estimated node cost across devices, most expensive nodes first.
    CostBased,
    /// Start from a single-device plan and move nodes only when that lowers
    /// the number of cross-device transfers.
    MinimizeTransfer,
    /// User-defined placement. `PlacementOptimizer::place` currently treats
    /// this exactly like [`PlacementStrategy::SingleDevice`].
    Custom,
}
56
57#[derive(Debug, Clone)]
59pub struct PlacementPlan {
60 pub node_placement: HashMap<usize, Device>,
62 pub tensor_placement: HashMap<usize, Device>,
64 pub transfer_cost: f64,
66}
67
68impl PlacementPlan {
69 pub fn new() -> Self {
70 PlacementPlan {
71 node_placement: HashMap::new(),
72 tensor_placement: HashMap::new(),
73 transfer_cost: 0.0,
74 }
75 }
76
77 pub fn single_device(num_nodes: usize, num_tensors: usize, device: Device) -> Self {
79 let mut plan = PlacementPlan::new();
80
81 for i in 0..num_nodes {
82 plan.node_placement.insert(i, device);
83 }
84
85 for i in 0..num_tensors {
86 plan.tensor_placement.insert(i, device);
87 }
88
89 plan
90 }
91
92 pub fn get_node_device(&self, node_idx: usize) -> Option<Device> {
94 self.node_placement.get(&node_idx).copied()
95 }
96
97 pub fn get_tensor_device(&self, tensor_idx: usize) -> Option<Device> {
99 self.tensor_placement.get(&tensor_idx).copied()
100 }
101
102 pub fn count_transfers(&self, graph: &EinsumGraph) -> usize {
104 let mut transfers = 0;
105
106 let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
108 for (node_idx, node) in graph.nodes.iter().enumerate() {
109 for &output_idx in &node.outputs {
110 tensor_producers.insert(output_idx, node_idx);
111 }
112 }
113
114 for (node_idx, node) in graph.nodes.iter().enumerate() {
115 let node_device = self.get_node_device(node_idx);
116
117 for &input_idx in &node.inputs {
118 let input_device = if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
120 self.get_node_device(producer_idx)
122 } else {
123 self.get_tensor_device(input_idx)
125 };
126
127 if node_device != input_device {
128 transfers += 1;
129 }
130 }
131 }
132
133 transfers
134 }
135
136 pub fn devices(&self) -> Vec<Device> {
138 let mut devices: Vec<_> = self.node_placement.values().copied().collect();
139 devices.sort_by(|a, b| {
140 a.device_id
141 .cmp(&b.device_id)
142 .then_with(|| format!("{:?}", a.device_type).cmp(&format!("{:?}", b.device_type)))
143 });
144 devices.dedup();
145 devices
146 }
147
148 pub fn summary(&self) -> String {
150 let devices = self.devices();
151 format!(
152 "Placement Plan:\n\
153 - Nodes: {}\n\
154 - Tensors: {}\n\
155 - Devices: {} ({:?})\n\
156 - Transfer cost: {:.2}",
157 self.node_placement.len(),
158 self.tensor_placement.len(),
159 devices.len(),
160 devices.iter().map(|d| d.as_str()).collect::<Vec<_>>(),
161 self.transfer_cost
162 )
163 }
164}
165
166impl Default for PlacementPlan {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172pub struct PlacementOptimizer {
174 strategy: PlacementStrategy,
175 available_devices: Vec<Device>,
176}
177
178impl PlacementOptimizer {
179 pub fn new(strategy: PlacementStrategy, available_devices: Vec<Device>) -> Self {
180 PlacementOptimizer {
181 strategy,
182 available_devices,
183 }
184 }
185
186 pub fn single_device(device: Device) -> Self {
188 PlacementOptimizer {
189 strategy: PlacementStrategy::SingleDevice,
190 available_devices: vec![device],
191 }
192 }
193
194 pub fn place(&self, graph: &EinsumGraph) -> PlacementPlan {
196 match self.strategy {
197 PlacementStrategy::SingleDevice => self.place_single_device(graph),
198 PlacementStrategy::RoundRobin => self.place_round_robin(graph),
199 PlacementStrategy::CostBased => self.place_cost_based(graph),
200 PlacementStrategy::MinimizeTransfer => self.place_minimize_transfer(graph),
201 PlacementStrategy::Custom => self.place_single_device(graph), }
203 }
204
205 pub fn place_with_schedule(
207 &self,
208 graph: &EinsumGraph,
209 schedule: &ExecutionSchedule,
210 ) -> PlacementPlan {
211 let mut plan = self.place(graph);
212
213 for (node_idx, device_type) in &schedule.device_placement {
215 if let Some(device) = self.find_device(*device_type) {
216 plan.node_placement.insert(*node_idx, device);
217 }
218 }
219
220 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
222
223 plan
224 }
225
226 fn place_single_device(&self, graph: &EinsumGraph) -> PlacementPlan {
227 let device = self
228 .available_devices
229 .first()
230 .copied()
231 .unwrap_or(Device::default_cpu());
232 PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), device)
233 }
234
235 fn place_round_robin(&self, graph: &EinsumGraph) -> PlacementPlan {
236 let mut plan = PlacementPlan::new();
237
238 if self.available_devices.is_empty() {
239 return plan;
240 }
241
242 for (idx, _) in graph.tensors.iter().enumerate() {
244 let device = self.available_devices[idx % self.available_devices.len()];
245 plan.tensor_placement.insert(idx, device);
246 }
247
248 for (idx, _) in graph.nodes.iter().enumerate() {
250 let device = self.available_devices[idx % self.available_devices.len()];
251 plan.node_placement.insert(idx, device);
252 }
253
254 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
255 plan
256 }
257
258 fn place_cost_based(&self, graph: &EinsumGraph) -> PlacementPlan {
259 use crate::scheduling::NodeCost;
260
261 let mut plan = PlacementPlan::new();
262
263 if self.available_devices.is_empty() {
264 return plan;
265 }
266
267 let costs: Vec<f64> = graph
269 .nodes
270 .iter()
271 .map(|node| NodeCost::estimate_from_node(node).total_cost())
272 .collect();
273
274 let mut device_loads = vec![0.0; self.available_devices.len()];
276
277 let mut node_order: Vec<_> = costs.iter().enumerate().collect();
279 node_order.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
280
281 for (node_idx, &cost) in node_order {
283 let min_device_idx = device_loads
284 .iter()
285 .enumerate()
286 .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
287 .map(|(idx, _)| idx)
288 .unwrap_or(0);
289
290 device_loads[min_device_idx] += cost;
291 plan.node_placement
292 .insert(node_idx, self.available_devices[min_device_idx]);
293 }
294
295 let _num_tensors = graph.tensors.len();
297 for (tensor_idx, _) in graph.tensors.iter().enumerate() {
298 let consumer_device = graph
300 .nodes
301 .iter()
302 .enumerate()
303 .find(|(_, node)| node.inputs.contains(&tensor_idx))
304 .and_then(|(node_idx, _)| plan.node_placement.get(&node_idx))
305 .copied()
306 .unwrap_or(self.available_devices[0]);
307
308 plan.tensor_placement.insert(tensor_idx, consumer_device);
309 }
310
311 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
312 plan
313 }
314
315 fn place_minimize_transfer(&self, graph: &EinsumGraph) -> PlacementPlan {
316 let mut plan = PlacementPlan::new();
317
318 if self.available_devices.is_empty() {
319 return plan;
320 }
321
322 let default_device = self.available_devices[0];
324 plan = PlacementPlan::single_device(graph.nodes.len(), graph.tensors.len(), default_device);
325
326 let mut improved = true;
328 let max_iterations = 10;
329 let mut iteration = 0;
330
331 while improved && iteration < max_iterations {
332 improved = false;
333 iteration += 1;
334
335 let current_transfers = plan.count_transfers(graph);
336
337 for node_idx in 0..graph.nodes.len() {
338 let current_device = plan
339 .get_node_device(node_idx)
340 .expect("node was placed before optimization loop");
341
342 for &candidate_device in &self.available_devices {
344 if candidate_device == current_device {
345 continue;
346 }
347
348 plan.node_placement.insert(node_idx, candidate_device);
350 let new_transfers = plan.count_transfers(graph);
351
352 if new_transfers < current_transfers {
353 improved = true;
355 break;
356 } else {
357 plan.node_placement.insert(node_idx, current_device);
359 }
360 }
361 }
362 }
363
364 plan.transfer_cost = self.estimate_transfer_cost(graph, &plan);
365 plan
366 }
367
368 fn estimate_transfer_cost(&self, graph: &EinsumGraph, plan: &PlacementPlan) -> f64 {
369 plan.count_transfers(graph) as f64
371 }
372
373 fn find_device(&self, device_type: DeviceType) -> Option<Device> {
374 self.available_devices
375 .iter()
376 .find(|d| d.device_type == device_type)
377 .copied()
378 }
379
380 pub fn strategy(&self) -> PlacementStrategy {
381 self.strategy
382 }
383
384 pub fn available_devices(&self) -> &[Device] {
385 &self.available_devices
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    /// Builds a two-node chain: `t2 = einsum(x, y)` followed by
    /// `t3 = relu(t2)`. Tensors 0 and 1 are graph inputs; tensors 2 and 3 are
    /// produced by nodes 0 and 1 respectively.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.tensors.push("x".to_string());
        graph.tensors.push("y".to_string());
        graph.tensors.push("t2".to_string());
        graph.tensors.push("t3".to_string());

        graph.nodes.push(EinsumNode {
            inputs: vec![0, 1],
            outputs: vec![2],
            op: OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
            metadata: None,
        });

        graph.nodes.push(EinsumNode {
            inputs: vec![2],
            outputs: vec![3],
            op: OpType::ElemUnary { op: "relu".into() },
            metadata: None,
        });

        graph
    }

    #[test]
    fn test_device_creation() {
        let cpu = Device::cpu(0);
        assert_eq!(cpu.device_type, DeviceType::CPU);
        assert_eq!(cpu.device_id, 0);
        assert_eq!(cpu.as_str(), "CPU:0");

        let gpu = Device::gpu(1);
        assert_eq!(gpu.device_type, DeviceType::GPU);
        assert_eq!(gpu.device_id, 1);
        assert_eq!(gpu.as_str(), "GPU:1");
    }

    #[test]
    fn test_placement_plan_single_device() {
        let device = Device::cpu(0);
        let plan = PlacementPlan::single_device(3, 2, device);

        assert_eq!(plan.node_placement.len(), 3);
        assert_eq!(plan.tensor_placement.len(), 2);
        assert_eq!(plan.get_node_device(0), Some(device));
        assert_eq!(plan.get_tensor_device(0), Some(device));
        // Indices outside the requested ranges are not placed.
        assert_eq!(plan.get_node_device(3), None);
        assert_eq!(plan.get_tensor_device(2), None);
    }

    #[test]
    fn test_placement_plan_devices() {
        let mut plan = PlacementPlan::new();
        plan.node_placement.insert(0, Device::cpu(0));
        plan.node_placement.insert(1, Device::gpu(0));
        plan.node_placement.insert(2, Device::gpu(1));
        // A duplicate placement must not produce a duplicate entry.
        plan.node_placement.insert(3, Device::gpu(1));

        let devices = plan.devices();
        assert_eq!(devices.len(), 3);
        assert!(devices.contains(&Device::cpu(0)));
        assert!(devices.contains(&Device::gpu(0)));
        assert!(devices.contains(&Device::gpu(1)));
    }

    #[test]
    fn test_single_device_placement() {
        let graph = create_test_graph();
        let optimizer = PlacementOptimizer::single_device(Device::cpu(0));
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        assert_eq!(plan.tensor_placement.len(), 4);
        // Everything lives on one device, so nothing needs to move.
        assert_eq!(plan.count_transfers(&graph), 0);
    }

    #[test]
    fn test_round_robin_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::RoundRobin, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        // Two nodes over two devices must land on different devices.
        let dev0 = plan.get_node_device(0);
        let dev1 = plan.get_node_device(1);
        assert_ne!(dev0, dev1);
    }

    #[test]
    fn test_cost_based_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::gpu(0)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::CostBased, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        assert_eq!(plan.tensor_placement.len(), 4);
        // The stored cost must agree with a fresh count on the final plan.
        assert_eq!(plan.transfer_cost, plan.count_transfers(&graph) as f64);
    }

    #[test]
    fn test_minimize_transfer_placement() {
        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];
        let optimizer = PlacementOptimizer::new(PlacementStrategy::MinimizeTransfer, devices);
        let plan = optimizer.place(&graph);

        assert_eq!(plan.node_placement.len(), 2);
        // The search starts from a single-device plan and only accepts
        // strict improvements, so it can never end up worse than that.
        let single_device_plan = PlacementOptimizer::single_device(Device::cpu(0)).place(&graph);
        assert!(plan.count_transfers(&graph) <= single_device_plan.count_transfers(&graph));
    }

    #[test]
    fn test_transfer_counting() {
        let graph = create_test_graph();

        // Same device everywhere: no transfers.
        let plan1 = PlacementPlan::single_device(2, 4, Device::cpu(0));
        assert_eq!(plan1.count_transfers(&graph), 0);

        // Node 0 on the CPU, node 1 on the GPU: only t2 (produced by node 0,
        // consumed by node 1) crosses devices.
        let mut plan2 = PlacementPlan::new();
        plan2.node_placement.insert(0, Device::cpu(0));
        plan2.node_placement.insert(1, Device::gpu(0));
        plan2.tensor_placement.insert(0, Device::cpu(0));
        plan2.tensor_placement.insert(1, Device::cpu(0));
        plan2.tensor_placement.insert(2, Device::cpu(0));
        plan2.tensor_placement.insert(3, Device::gpu(0));

        let transfers = plan2.count_transfers(&graph);
        assert_eq!(transfers, 1);
    }

    #[test]
    fn test_placement_strategies() {
        let strategies = vec![
            PlacementStrategy::SingleDevice,
            PlacementStrategy::RoundRobin,
            PlacementStrategy::CostBased,
            PlacementStrategy::MinimizeTransfer,
            PlacementStrategy::Custom,
        ];

        let graph = create_test_graph();
        let devices = vec![Device::cpu(0), Device::cpu(1)];

        for strategy in strategies {
            let optimizer = PlacementOptimizer::new(strategy, devices.clone());
            let plan = optimizer.place(&graph);
            assert_eq!(
                plan.node_placement.len(),
                graph.nodes.len(),
                "Strategy {:?} failed",
                strategy
            );
        }
    }

    #[test]
    fn test_placement_summary() {
        let plan = PlacementPlan::single_device(5, 3, Device::cpu(0));
        let summary = plan.summary();

        assert!(summary.contains("Placement Plan"));
        assert!(summary.contains("Nodes: 5"));
        assert!(summary.contains("Tensors: 3"));
        assert!(summary.contains("Devices: 1"));
        assert!(summary.contains("Transfer cost: 0.00"));
    }
}