1use crate::{Error, Result};
18use arrow::array::Int32Array;
19use wgpu;
20
21#[derive(Debug, Clone)]
23pub struct GpuDeviceInfo {
24 pub name: String,
26 pub device_type: wgpu::DeviceType,
28 pub backend: wgpu::Backend,
30}
31
32pub struct MultiGpuManager {
34 devices: Vec<GpuDeviceInfo>,
36}
37
38impl MultiGpuManager {
39 pub fn new() -> Result<Self> {
44 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
45 backends: wgpu::Backends::all(),
46 ..Default::default()
47 });
48
49 let adapters = instance.enumerate_adapters(wgpu::Backends::all());
51
52 let devices: Vec<GpuDeviceInfo> = adapters
54 .iter()
55 .map(|adapter| {
56 let info = adapter.get_info();
57 GpuDeviceInfo {
58 name: info.name,
59 device_type: info.device_type,
60 backend: info.backend,
61 }
62 })
63 .collect();
64
65 Ok(Self { devices })
66 }
67
68 #[must_use]
70 pub fn device_count(&self) -> usize {
71 self.devices.len()
72 }
73
74 #[must_use]
76 pub fn devices(&self) -> &[GpuDeviceInfo] {
77 &self.devices
78 }
79}
80
/// Selects how `partition_data` distributes elements across partitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionStrategy {
    /// Contiguous index ranges: partition 0 receives the first run of elements, partition 1
    /// the next, and so on. When the length is not evenly divisible, the leading partitions
    /// each receive one extra element.
    Range,
    /// Elements are scattered across partitions by a hash-derived bucket index
    /// (`hash % num_partitions`).
    Hash,
}
91
92#[derive(Debug)]
94pub struct DataPartition {
95 pub device_id: usize,
97 pub data: Int32Array,
99}
100
101pub fn partition_data(
114 data: &Int32Array,
115 num_partitions: usize,
116 strategy: PartitionStrategy,
117) -> Result<Vec<DataPartition>> {
118 if num_partitions == 0 {
119 return Err(Error::InvalidInput("num_partitions must be > 0".to_string()));
120 }
121
122 let partitions = match strategy {
123 PartitionStrategy::Range => partition_range(data, num_partitions),
124 PartitionStrategy::Hash => partition_hash(data, num_partitions),
125 };
126 Ok(partitions)
127}
128
129fn partition_range(data: &Int32Array, num_partitions: usize) -> Vec<DataPartition> {
131 let len = data.len();
132 let mut partitions = Vec::with_capacity(num_partitions);
133
134 let base_size = len / num_partitions;
136 let remainder = len % num_partitions;
137
138 let mut offset = 0;
139 for device_id in 0..num_partitions {
140 let size = if device_id < remainder { base_size + 1 } else { base_size };
142
143 let values: Vec<i32> = (offset..offset + size).map(|i| data.value(i)).collect();
145
146 partitions.push(DataPartition { device_id, data: Int32Array::from(values) });
147
148 offset += size;
149 }
150
151 partitions
152}
153
154fn partition_hash(data: &Int32Array, num_partitions: usize) -> Vec<DataPartition> {
156 use std::collections::hash_map::DefaultHasher;
157 use std::hash::{Hash, Hasher};
158
159 let mut buckets: Vec<Vec<i32>> = (0..num_partitions).map(|_| Vec::new()).collect();
161
162 for i in 0..data.len() {
164 let value = data.value(i);
165
166 let mut hasher = DefaultHasher::new();
168 i.hash(&mut hasher);
169 let hash = hasher.finish();
170
171 #[allow(clippy::cast_possible_truncation)]
172 let partition_id = (hash % num_partitions as u64) as usize;
173 buckets[partition_id].push(value);
174 }
175
176 let partitions: Vec<DataPartition> = buckets
178 .into_iter()
179 .enumerate()
180 .map(|(device_id, values)| DataPartition { device_id, data: Int32Array::from(values) })
181 .collect();
182
183 partitions
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    // Construction must succeed, and every adapter that is reported must carry a name.
    #[test]
    fn test_multigpu_device_detection() {
        let mgr = MultiGpuManager::new()
            .unwrap_or_else(|e| panic!("MultiGpuManager::new() failed: {e}"));

        let count = mgr.device_count();
        println!("Detected {count} GPU device(s)");

        // Iterating an empty slice is a no-op, so machines without a GPU pass trivially.
        for (i, device) in mgr.devices().iter().enumerate() {
            println!("GPU {i}: {device:?}");
            assert!(!device.name.is_empty(), "Device name should not be empty");
        }
    }

    // Smoke test: construction and counting must not panic, whether or not a GPU exists.
    // No particular count is asserted because the result depends on the host machine.
    #[test]
    fn test_multigpu_device_count_zero_when_no_gpu() {
        if let Ok(mgr) = MultiGpuManager::new() {
            let _ = mgr.device_count();
        }
    }

    // Eight elements over two partitions: two contiguous halves of four.
    #[test]
    fn test_partition_range_even_split() {
        let data = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let partitions = partition_data(&data, 2, PartitionStrategy::Range).unwrap();

        assert_eq!(partitions.len(), 2);

        // (first value, last value) expected in each partition, in partition order.
        let expected_bounds = [(1, 4), (5, 8)];
        for (id, &(first, last)) in expected_bounds.iter().enumerate() {
            let part = &partitions[id];
            assert_eq!(part.device_id, id);
            assert_eq!(part.data.len(), 4);
            assert_eq!(part.data.value(0), first);
            assert_eq!(part.data.value(3), last);
        }
    }

    // Ten elements over three partitions: nothing lost, order preserved end to end.
    #[test]
    fn test_partition_range_uneven_split() {
        let data = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let partitions = partition_data(&data, 3, PartitionStrategy::Range).unwrap();

        assert_eq!(partitions.len(), 3);

        let total_len = partitions.iter().fold(0, |acc, p| acc + p.data.len());
        assert_eq!(total_len, 10);

        assert_eq!(partitions[0].data.value(0), 1);

        let tail = &partitions[2].data;
        assert_eq!(tail.value(tail.len() - 1), 10);
    }

    // Only partition count, total element count and id assignment are checked; which
    // bucket an individual element lands in is left unconstrained.
    #[test]
    fn test_partition_hash_distribution() {
        let data = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let partitions = partition_data(&data, 2, PartitionStrategy::Hash).unwrap();

        assert_eq!(partitions.len(), 2);

        let total_len = partitions.iter().fold(0, |acc, p| acc + p.data.len());
        assert_eq!(total_len, 8);

        for (id, part) in partitions.iter().enumerate() {
            assert_eq!(part.device_id, id);
        }
    }

    // A single partition receives the entire input.
    #[test]
    fn test_partition_single_gpu() {
        let data = Int32Array::from(vec![1, 2, 3, 4]);
        let partitions = partition_data(&data, 1, PartitionStrategy::Range).unwrap();

        assert_eq!(partitions.len(), 1);

        let only = &partitions[0];
        assert_eq!(only.device_id, 0);
        assert_eq!(only.data.len(), 4);
    }

    // Empty input still yields the requested number of partitions, all of them empty.
    #[test]
    fn test_partition_empty_data() {
        let data = Int32Array::from(Vec::<i32>::new());
        let partitions = partition_data(&data, 2, PartitionStrategy::Range).unwrap();

        assert_eq!(partitions.len(), 2);
        for part in &partitions {
            assert_eq!(part.data.len(), 0);
        }
    }

    // Requesting zero partitions is rejected with a descriptive error.
    #[test]
    fn test_partition_zero_partitions_error() {
        let data = Int32Array::from(vec![1, 2, 3]);

        // `unwrap_err` itself fails the test if the call unexpectedly succeeds.
        let err = partition_data(&data, 0, PartitionStrategy::Range).unwrap_err();
        assert!(err.to_string().contains("num_partitions must be > 0"));
    }
}