1use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use crate::error::{NdimageError, NdimageResult};
11
/// Hardware description of a single compute device as reported by one backend.
///
/// Depending on the backend, values are either queried from the system or estimated from
/// the device name by the detection routines in this module.
#[derive(Debug, Clone)]
pub struct DeviceCapability {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory, in bytes.
    pub total_memory: usize,
    /// Memory considered available for new allocations, in bytes.
    pub available_memory: usize,
    /// CUDA compute capability as `(major, minor)`; left as `None` by non-CUDA detectors.
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum number of threads in one block / work-group.
    pub max_threads_per_block: Option<usize>,
    /// Maximum block (work-group) extent along x, y and z.
    pub max_block_dims: Option<[usize; 3]>,
    /// Maximum grid extent along x, y and z.
    pub max_grid_dims: Option<[usize; 3]>,
    /// Shared/local memory per block — presumably bytes (e.g. `49152`); TODO confirm.
    pub shared_memory_per_block: Option<usize>,
    /// Number of multiprocessors / compute units.
    pub multiprocessor_count: Option<usize>,
    /// Core clock rate; detectors use values like `1_800_000` (kHz-scale) — TODO confirm unit.
    pub clock_rate: Option<usize>,
    /// Peak memory bandwidth; values like `448.0` suggest GB/s — TODO confirm unit.
    pub memory_bandwidth: Option<f64>,
}

impl Default for DeviceCapability {
    /// A placeholder called `"Unknown Device"` with zero memory and every optional
    /// capability unset.
    fn default() -> Self {
        Self {
            name: String::from("Unknown Device"),
            total_memory: 0,
            available_memory: 0,
            multiprocessor_count: None,
            clock_rate: None,
            memory_bandwidth: None,
            compute_capability: None,
            max_threads_per_block: None,
            max_block_dims: None,
            max_grid_dims: None,
            shared_memory_per_block: None,
        }
    }
}
56
/// Summary of the GPU hardware detected across all compiled-in backends.
#[derive(Debug, Clone)]
pub struct SystemCapabilities {
    /// At least one CUDA device was detected.
    pub cuda_available: bool,
    /// At least one OpenCL device was detected.
    pub opencl_available: bool,
    /// At least one Metal device was detected.
    pub metal_available: bool,
    /// Any of the three backends above has a device.
    pub gpu_available: bool,
    /// Total memory of the largest single GPU, in MiB (a maximum, not a sum across GPUs).
    pub gpu_memory_mb: usize,
    /// Largest multiprocessor / compute-unit count of any detected GPU.
    pub compute_units: u32,
}
67
/// Registry of the devices found by each GPU backend compiled into this build.
///
/// Every field is gated on its backend's cargo feature (and, for Metal, on macOS), so a
/// build without GPU features produces a manager with no fields at all.
pub struct DeviceManager {
    /// Devices reported by CUDA detection.
    #[cfg(feature = "cuda")]
    cuda_devices: Vec<DeviceCapability>,
    /// Devices reported by OpenCL detection.
    #[cfg(feature = "opencl")]
    opencl_devices: Vec<DeviceCapability>,
    /// Devices reported by Metal detection.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    metal_devices: Vec<DeviceCapability>,
}
77
78impl DeviceManager {
79 pub fn new() -> NdimageResult<Self> {
81 let mut manager = Self {
82 #[cfg(feature = "cuda")]
83 cuda_devices: Vec::new(),
84 #[cfg(feature = "opencl")]
85 opencl_devices: Vec::new(),
86 #[cfg(all(target_os = "macos", feature = "metal"))]
87 metal_devices: Vec::new(),
88 };
89
90 #[cfg(feature = "cuda")]
92 {
93 manager.cuda_devices = detect_cuda_devices()?;
94 }
95
96 #[cfg(feature = "opencl")]
97 {
98 manager.opencl_devices = detect_opencl_devices()?;
99 }
100
101 #[cfg(all(target_os = "macos", feature = "metal"))]
102 {
103 manager.metal_devices = detect_metal_devices()?;
104 }
105
106 Ok(manager)
107 }
108
109 pub fn get_best_device(&self, requiredmemory: usize) -> Option<(super::Backend, usize)> {
111 let mut best_device = None;
112 let mut best_score = 0.0;
113
114 #[cfg(feature = "cuda")]
115 {
116 for (idx, device) in self.cuda_devices.iter().enumerate() {
117 if device.available_memory >= requiredmemory {
118 let score = self.calculate_device_score(device);
119 if score > best_score {
120 best_score = score;
121 best_device = Some((super::Backend::Cuda, idx));
122 }
123 }
124 }
125 }
126
127 #[cfg(feature = "opencl")]
128 {
129 for (idx, device) in self.opencl_devices.iter().enumerate() {
130 if device.available_memory >= requiredmemory {
131 let score = self.calculate_device_score(device) * 0.9; if score > best_score {
133 best_score = score;
134 best_device = Some((super::Backend::OpenCL, idx));
135 }
136 }
137 }
138 }
139
140 #[cfg(all(target_os = "macos", feature = "metal"))]
141 {
142 for (idx, device) in self.metal_devices.iter().enumerate() {
143 if device.available_memory >= requiredmemory {
144 let score = self.calculate_device_score(device) * 0.8; if score > best_score {
146 best_score = score;
147 best_device = Some((super::Backend::Metal, idx));
148 }
149 }
150 }
151 }
152
153 best_device
154 }
155
156 fn calculate_device_score(&self, device: &DeviceCapability) -> f64 {
158 let mut score = 0.0;
159
160 score += (device.total_memory as f64) / (1024.0 * 1024.0 * 1024.0) * 10.0;
162
163 if let Some(mp_count) = device.multiprocessor_count {
165 score += (mp_count as f64) * 5.0;
166 }
167
168 if let Some(clock) = device.clock_rate {
170 score += (clock as f64) / 1_000_000.0 * 3.0;
171 }
172
173 if let Some(bandwidth) = device.memory_bandwidth {
175 score += bandwidth * 0.1;
176 }
177
178 score
179 }
180
181 pub fn get_device_info(
183 &self,
184 backend: super::Backend,
185 device_id: usize,
186 ) -> Option<&DeviceCapability> {
187 match backend {
188 #[cfg(feature = "cuda")]
189 super::Backend::Cuda => self.cuda_devices.get(device_id),
190 #[cfg(feature = "opencl")]
191 super::Backend::OpenCL => self.opencl_devices.get(device_id),
192 #[cfg(all(target_os = "macos", feature = "metal"))]
193 super::Backend::Metal => self.metal_devices.get(device_id),
194 _ => None,
195 }
196 }
197
198 pub fn is_backend_available(&self, backend: super::Backend) -> bool {
200 match backend {
201 #[cfg(feature = "cuda")]
202 super::Backend::Cuda => !self.cuda_devices.is_empty(),
203 #[cfg(feature = "opencl")]
204 super::Backend::OpenCL => !self.opencl_devices.is_empty(),
205 #[cfg(all(target_os = "macos", feature = "metal"))]
206 super::Backend::Metal => !self.metal_devices.is_empty(),
207 super::Backend::Cpu => true,
208 super::Backend::Auto => {
209 #[cfg(feature = "cuda")]
210 if !self.cuda_devices.is_empty() {
211 return true;
212 }
213 #[cfg(feature = "opencl")]
214 if !self.opencl_devices.is_empty() {
215 return true;
216 }
217 #[cfg(all(target_os = "macos", feature = "metal"))]
218 if !self.metal_devices.is_empty() {
219 return true;
220 }
221 true }
223 }
224 }
225
226 pub fn device_count(&self, backend: super::Backend) -> usize {
228 match backend {
229 #[cfg(feature = "cuda")]
230 super::Backend::Cuda => self.cuda_devices.len(),
231 #[cfg(feature = "opencl")]
232 super::Backend::OpenCL => self.opencl_devices.len(),
233 #[cfg(all(target_os = "macos", feature = "metal"))]
234 super::Backend::Metal => self.metal_devices.len(),
235 super::Backend::Cpu => 1,
236 super::Backend::Auto => {
237 let mut total = 1; #[cfg(feature = "cuda")]
239 {
240 total += self.cuda_devices.len();
241 }
242 #[cfg(feature = "opencl")]
243 {
244 total += self.opencl_devices.len();
245 }
246 #[cfg(all(target_os = "macos", feature = "metal"))]
247 {
248 total += self.metal_devices.len();
249 }
250 total
251 }
252 }
253 }
254
255 pub fn get_capabilities(&self) -> SystemCapabilities {
257 let cuda_available = {
258 #[cfg(feature = "cuda")]
259 {
260 !self.cuda_devices.is_empty()
261 }
262 #[cfg(not(feature = "cuda"))]
263 {
264 false
265 }
266 };
267
268 let opencl_available = {
269 #[cfg(feature = "opencl")]
270 {
271 !self.opencl_devices.is_empty()
272 }
273 #[cfg(not(feature = "opencl"))]
274 {
275 false
276 }
277 };
278
279 let metal_available = {
280 #[cfg(all(target_os = "macos", feature = "metal"))]
281 {
282 !self.metal_devices.is_empty()
283 }
284 #[cfg(not(all(target_os = "macos", feature = "metal")))]
285 {
286 false
287 }
288 };
289
290 let gpu_available = cuda_available || opencl_available || metal_available;
291
292 let mut total_memory_mb = 0;
294 let mut max_compute_units = 0;
295
296 #[cfg(feature = "cuda")]
297 {
298 for device in &self.cuda_devices {
299 total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
300 if let Some(mp_count) = device.multiprocessor_count {
301 max_compute_units = max_compute_units.max(mp_count as u32);
302 }
303 }
304 }
305
306 #[cfg(feature = "opencl")]
307 {
308 for device in &self.opencl_devices {
309 total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
310 if let Some(mp_count) = device.multiprocessor_count {
311 max_compute_units = max_compute_units.max(mp_count as u32);
312 }
313 }
314 }
315
316 #[cfg(all(target_os = "macos", feature = "metal"))]
317 {
318 for device in &self.metal_devices {
319 total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
320 if let Some(mp_count) = device.multiprocessor_count {
321 max_compute_units = max_compute_units.max(mp_count as u32);
322 }
323 }
324 }
325
326 SystemCapabilities {
327 cuda_available,
328 opencl_available,
329 metal_available,
330 gpu_available,
331 gpu_memory_mb: total_memory_mb,
332 compute_units: max_compute_units,
333 }
334 }
335}
336
337static DEVICE_MANAGER: OnceLock<Arc<Mutex<DeviceManager>>> = OnceLock::new();
339
340#[allow(dead_code)]
342pub fn get_device_manager() -> NdimageResult<Arc<Mutex<DeviceManager>>> {
343 let result = DEVICE_MANAGER.get_or_init(|| {
344 match DeviceManager::new() {
345 Ok(manager) => Arc::new(Mutex::new(manager)),
346 Err(_) => {
347 Arc::new(Mutex::new(DeviceManager {
349 #[cfg(feature = "cuda")]
350 cuda_devices: Vec::new(),
351 #[cfg(feature = "opencl")]
352 opencl_devices: Vec::new(),
353 #[cfg(all(target_os = "macos", feature = "metal"))]
354 metal_devices: Vec::new(),
355 }))
356 }
357 }
358 });
359 Ok(result.clone())
360}
361
/// Detects NVIDIA GPUs by querying `nvidia-smi`.
///
/// Returns an empty list when no CUDA runtime library exists at the probed Linux paths and
/// `CUDA_PATH` is unset. If the runtime looks present but `nvidia-smi` produces no usable
/// rows, returns a single placeholder "Generic CUDA Device" with fixed, assumed specs.
/// Never returns `Err`.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn detect_cuda_devices() -> NdimageResult<Vec<DeviceCapability>> {
    const MIB: usize = 1024 * 1024;

    let runtime_candidates = [
        "/usr/local/cuda/lib64/libcudart.so",
        "/usr/lib/x86_64-linux-gnu/libcudart.so",
    ];
    let runtime_present = runtime_candidates
        .iter()
        .any(|path| std::path::Path::new(path).exists())
        || std::env::var("CUDA_PATH").is_ok();
    if !runtime_present {
        return Ok(Vec::new());
    }

    let mut devices = Vec::new();

    // One CSV row per GPU: "<name>, <total MiB>, <free MiB>".
    let smi_stdout = std::process::Command::new("nvidia-smi")
        .arg("--query-gpu=name,memory.total,memory.free")
        .arg("--format=csv,noheader,nounits")
        .output()
        .ok()
        .filter(|out| out.status.success())
        .map(|out| String::from_utf8_lossy(&out.stdout).into_owned());

    if let Some(text) = smi_stdout {
        for (row, line) in text.lines().enumerate() {
            let fields: Vec<&str> = line.split(',').map(str::trim).collect();
            let [name, total_mib, free_mib, ..] = fields.as_slice() else {
                continue;
            };

            let name = String::from(*name);
            // Unparseable memory fields (e.g. "[N/A]") count as 0 bytes.
            let total_memory = total_mib.parse::<usize>().unwrap_or(0) * MIB;
            let available_memory = free_mib.parse::<usize>().unwrap_or(0) * MIB;
            let (compute_capability, multiprocessor_count, clock_rate) =
                estimate_gpu_capabilities(&name);

            devices.push(DeviceCapability {
                name: format!("{} (CUDA Device {})", name, row),
                total_memory,
                available_memory,
                compute_capability,
                max_threads_per_block: Some(1024),
                max_block_dims: Some([1024, 1024, 64]),
                max_grid_dims: Some([65535, 65535, 65535]),
                shared_memory_per_block: Some(49152),
                multiprocessor_count,
                clock_rate,
                memory_bandwidth: estimate_memory_bandwidth(&name),
            });
        }
    }

    if devices.is_empty() {
        // Assumed specs used when nvidia-smi is missing or reports nothing parseable.
        devices.push(DeviceCapability {
            name: "Generic CUDA Device".to_string(),
            total_memory: 8_589_934_592,
            available_memory: 7_516_192_768,
            compute_capability: Some((7, 5)),
            max_threads_per_block: Some(1024),
            max_block_dims: Some([1024, 1024, 64]),
            max_grid_dims: Some([65535, 65535, 65535]),
            shared_memory_per_block: Some(49152),
            multiprocessor_count: Some(68),
            clock_rate: Some(1_800_000),
            memory_bandwidth: Some(448.0),
        });
    }

    Ok(devices)
}
443
/// Guesses `(compute capability, multiprocessor count, clock rate)` for an NVIDIA GPU from its
/// name.
///
/// Matching is a case-insensitive substring search. The first matching family in the table
/// wins, and unrecognised names get conservative defaults.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn estimate_gpu_capabilities(name: &str) -> (Option<(u32, u32)>, Option<usize>, Option<usize>) {
    // (name fragments, compute capability, multiprocessors, clock), checked top to bottom.
    const FAMILIES: [(&[&str], (u32, u32), usize, usize); 5] = [
        (&["rtx 40", "ada lovelace"], (8, 9), 128, 2_500_000),
        (&["rtx 30", "ampere"], (8, 6), 104, 1_700_000),
        (&["rtx 20", "turing"], (7, 5), 72, 1_500_000),
        (&["gtx 16", "gtx 10"], (6, 1), 20, 1_400_000),
        (&["tesla", "quadro"], (7, 0), 80, 1_300_000),
    ];
    const FALLBACK: ((u32, u32), usize, usize) = ((6, 0), 32, 1_000_000);

    let lowered = name.to_lowercase();
    let (capability, multiprocessors, clock) = FAMILIES
        .iter()
        .find(|(fragments, ..)| {
            fragments
                .iter()
                .any(|fragment| lowered.contains(*fragment))
        })
        .map_or(FALLBACK, |&(_, capability, multiprocessors, clock)| {
            (capability, multiprocessors, clock)
        });

    (Some(capability), Some(multiprocessors), Some(clock))
}
470
/// Guesses the peak memory bandwidth of an NVIDIA GPU from its name.
///
/// Uses a case-insensitive substring match against known models (first match wins) and falls
/// back to `320.0`. Always returns `Some`.
// NOTE(review): data-center names such as "NVIDIA A100-SXM4-40GB" may not contain "tesla",
// so the "tesla a100" key might never match in practice — confirm against real nvidia-smi names.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn estimate_memory_bandwidth(name: &str) -> Option<f64> {
    const KNOWN_MODELS: [(&str, f64); 8] = [
        ("rtx 4090", 1008.0),
        ("rtx 4080", 717.0),
        ("rtx 3090", 936.0),
        ("rtx 3080", 760.0),
        ("rtx 3070", 448.0),
        ("rtx 2080", 448.0),
        ("tesla v100", 900.0),
        ("tesla a100", 1555.0),
    ];

    let lowered = name.to_lowercase();
    let bandwidth = KNOWN_MODELS
        .iter()
        .find(|(model, _)| lowered.contains(*model))
        .map_or(320.0, |&(_, value)| value);
    Some(bandwidth)
}
496
/// Detects OpenCL devices by parsing `clinfo --list`.
///
/// Returns an empty list when no OpenCL loader exists at the probed Linux paths and
/// `OPENCL_ROOT` is unset. Memory, compute units and clock are estimated from each device
/// name. If nothing is parsed, placeholder Intel and/or AMD entries with fixed, assumed specs
/// are added, based on `/sys/class/drm/card0`, `HSA_ENABLE_SDMA` and `/opt/rocm`.
/// Never returns `Err`.
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn detect_opencl_devices() -> NdimageResult<Vec<DeviceCapability>> {
    let loader_candidates = [
        "/usr/lib/x86_64-linux-gnu/libOpenCL.so.1",
        "/usr/local/lib/libOpenCL.so",
    ];
    let loader_present = loader_candidates
        .iter()
        .any(|path| std::path::Path::new(path).exists())
        || std::env::var("OPENCL_ROOT").is_ok();
    if !loader_present {
        return Ok(Vec::new());
    }

    let mut devices = Vec::new();

    let listing = std::process::Command::new("clinfo")
        .arg("--list")
        .output()
        .ok()
        .filter(|out| out.status.success())
        .map(|out| String::from_utf8_lossy(&out.stdout).into_owned());

    if let Some(text) = listing {
        // The numbering below is the line index in the clinfo output, not a device ordinal.
        let device_lines = text
            .lines()
            .enumerate()
            .filter(|(_, line)| line.contains("Device") && !line.contains("Platform"));

        for (line_no, line) in device_lines {
            let device_name = line
                .split("Device")
                .nth(1)
                .unwrap_or("Unknown OpenCL Device")
                .trim()
                .to_string();
            let (memory_size, compute_units, clock_freq) =
                estimate_opencl_capabilities(&device_name);

            devices.push(DeviceCapability {
                name: format!("{} (OpenCL Device {})", device_name, line_no),
                total_memory: memory_size,
                // Assume 80% of the estimated memory is free.
                available_memory: (memory_size as f64 * 0.8) as usize,
                max_threads_per_block: Some(1024),
                max_block_dims: Some([1024, 1024, 1024]),
                shared_memory_per_block: Some(32768),
                multiprocessor_count: Some(compute_units),
                clock_rate: Some(clock_freq),
                memory_bandwidth: estimate_opencl_bandwidth(&device_name),
                ..DeviceCapability::default()
            });
        }
    }

    if devices.is_empty() {
        if std::path::Path::new("/sys/class/drm/card0").exists() {
            devices.push(DeviceCapability {
                name: "Intel Integrated Graphics (OpenCL)".to_string(),
                total_memory: 2_147_483_648,
                available_memory: 1_717_986_918,
                max_threads_per_block: Some(512),
                max_block_dims: Some([512, 512, 512]),
                shared_memory_per_block: Some(32768),
                multiprocessor_count: Some(24),
                clock_rate: Some(1_000_000),
                memory_bandwidth: Some(25.6),
                ..DeviceCapability::default()
            });
        }

        if std::env::var("HSA_ENABLE_SDMA").is_ok() || std::path::Path::new("/opt/rocm").exists() {
            devices.push(DeviceCapability {
                name: "AMD Discrete Graphics (OpenCL)".to_string(),
                total_memory: 8_589_934_592,
                available_memory: 6_871_947_674,
                max_threads_per_block: Some(1024),
                max_block_dims: Some([1024, 1024, 1024]),
                shared_memory_per_block: Some(65536),
                multiprocessor_count: Some(64),
                clock_rate: Some(1_500_000),
                memory_bandwidth: Some(448.0),
                ..DeviceCapability::default()
            });
        }
    }

    Ok(devices)
}
592
/// Guesses `(memory in bytes, compute units, clock rate)` for an OpenCL device from its name.
///
/// Dispatches first on vendor (Intel, AMD/Radeon, NVIDIA/GeForce/Quadro) and then on
/// product family, using case-insensitive substring matches. Unknown devices get a small
/// default.
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn estimate_opencl_capabilities(name: &str) -> (usize, usize, usize) {
    let lowered = name.to_lowercase();
    let mentions = |fragment: &str| lowered.contains(fragment);

    if mentions("intel") {
        return if mentions("iris") || mentions("xe") {
            (4_294_967_296, 96, 1_300_000)
        } else {
            (2_147_483_648, 24, 1_000_000)
        };
    }

    if mentions("amd") || mentions("radeon") {
        return if mentions("rx 7") || mentions("rx 6") {
            (16_106_127_360, 80, 2_000_000)
        } else if mentions("rx 5") {
            (8_589_934_592, 64, 1_800_000)
        } else {
            (4_294_967_296, 36, 1_500_000)
        };
    }

    if mentions("nvidia") || mentions("geforce") || mentions("quadro") {
        return if mentions("rtx") {
            (12_884_901_888, 84, 1_700_000)
        } else {
            (8_589_934_592, 56, 1_500_000)
        };
    }

    (2_147_483_648, 16, 1_000_000)
}
629
/// Guesses the peak memory bandwidth of an OpenCL device from its name.
///
/// The first matching fragment in the table wins (case-insensitive), so the specific Intel
/// entries take precedence over the generic `"intel"` one. Falls back to `100.0`. Always
/// returns `Some`.
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn estimate_opencl_bandwidth(name: &str) -> Option<f64> {
    const BY_FRAGMENT: [(&str, f64); 7] = [
        ("intel iris", 68.0),
        ("intel xe", 68.0),
        ("intel", 25.6),
        ("rx 7", 960.0),
        ("rx 6", 512.0),
        ("rx 5", 448.0),
        ("nvidia", 760.0),
    ];

    let lowered = name.to_lowercase();
    let bandwidth = BY_FRAGMENT
        .iter()
        .find(|(fragment, _)| lowered.contains(*fragment))
        .map_or(100.0, |&(_, value)| value);
    Some(bandwidth)
}
651
/// Collects Metal devices on macOS.
///
/// Returns the entry from `detect_macos_integrated_gpu` followed by any entries from
/// `detect_macos_discrete_gpus`. A failure in either probe is ignored, so this never
/// returns `Err`.
// NOTE(review): `detect_macos_integrated_gpu` yields an entry whenever `system_profiler`
// succeeds, including an "Unknown Integrated Graphics" placeholder when the report mentions
// neither "Intel" nor "AMD". On Apple-silicon machines that placeholder likely appears
// alongside the "Apple M…" entry — confirm whether this is intended.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_metal_devices() -> NdimageResult<Vec<DeviceCapability>> {
    let mut devices = Vec::new();

    if let Ok(integrated) = detect_macos_integrated_gpu() {
        devices.push(integrated);
    }

    if let Ok(discrete) = detect_macos_discrete_gpus() {
        devices.extend(discrete);
    }

    Ok(devices)
}
681
/// Builds one integrated-GPU entry from `system_profiler SPDisplaysDataType -xml`.
///
/// The specs are fixed estimates chosen by whether the report mentions "Intel", then "AMD".
/// Otherwise a generic placeholder is returned.
///
/// # Errors
///
/// Returns `ComputationError` if `system_profiler` cannot be run or exits unsuccessfully.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_macos_integrated_gpu() -> NdimageResult<DeviceCapability> {
    let output = std::process::Command::new("system_profiler")
        .args(["SPDisplaysDataType", "-xml"])
        .output()
        .map_err(|e| {
            NdimageError::ComputationError(format!("Failed to run systemprofiler: {}", e))
        })?;

    if !output.status.success() {
        return Err(NdimageError::ComputationError(
            "system_profiler failed".into(),
        ));
    }

    let report = String::from_utf8_lossy(&output.stdout);

    // Fields not listed keep their `Default` values (None / "Unknown Device" overwritten).
    let capability = if report.contains("Intel") {
        DeviceCapability {
            name: "Intel Integrated Graphics (Metal)".to_string(),
            total_memory: 1_073_741_824,
            available_memory: 805_306_368,
            multiprocessor_count: Some(16),
            clock_rate: Some(1_000_000),
            max_threads_per_block: Some(1024),
            max_block_dims: Some([1024, 1024, 64]),
            shared_memory_per_block: Some(32768),
            ..DeviceCapability::default()
        }
    } else if report.contains("AMD") {
        DeviceCapability {
            name: "AMD Integrated Graphics (Metal)".to_string(),
            total_memory: 2_147_483_648,
            available_memory: 1_610_612_736,
            multiprocessor_count: Some(32),
            clock_rate: Some(1_200_000),
            max_threads_per_block: Some(1024),
            max_block_dims: Some([1024, 1024, 64]),
            shared_memory_per_block: Some(65536),
            ..DeviceCapability::default()
        }
    } else {
        DeviceCapability {
            name: "Unknown Integrated Graphics (Metal)".to_string(),
            total_memory: 1_073_741_824,
            available_memory: 805_306_368,
            multiprocessor_count: Some(8),
            clock_rate: Some(800_000),
            max_threads_per_block: Some(512),
            max_block_dims: Some([512, 512, 64]),
            shared_memory_per_block: Some(16384),
            ..DeviceCapability::default()
        }
    };

    Ok(capability)
}
738
/// Builds entries for AMD Radeon and Apple-silicon GPUs mentioned in
/// `system_profiler SPDisplaysDataType -xml`.
///
/// The specs are fixed estimates keyed on model substrings. A non-zero exit status yields an
/// empty list rather than an error.
///
/// # Errors
///
/// Returns `ComputationError` only if `system_profiler` cannot be launched.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_macos_discrete_gpus() -> NdimageResult<Vec<DeviceCapability>> {
    /// Assembles an entry with the launch limits shared by every discrete Metal GPU here.
    fn metal_device(
        name: &str,
        total_memory: usize,
        available_memory: usize,
        multiprocessors: usize,
        clock: usize,
        shared_memory: usize,
    ) -> DeviceCapability {
        DeviceCapability {
            name: name.to_string(),
            total_memory,
            available_memory,
            multiprocessor_count: Some(multiprocessors),
            clock_rate: Some(clock),
            max_threads_per_block: Some(1024),
            max_block_dims: Some([1024, 1024, 1024]),
            shared_memory_per_block: Some(shared_memory),
            ..DeviceCapability::default()
        }
    }

    let output = std::process::Command::new("system_profiler")
        .args(["SPDisplaysDataType", "-xml"])
        .output()
        .map_err(|e| {
            NdimageError::ComputationError(format!("Failed to run systemprofiler: {}", e))
        })?;

    if !output.status.success() {
        return Ok(Vec::new());
    }

    let report = String::from_utf8_lossy(&output.stdout);
    let mut devices = Vec::new();

    if report.contains("Radeon") || report.contains("RX ") {
        let (name, total, available, units, clock) =
            if report.contains("RX 6800") || report.contains("RX 6900") {
                ("AMD Radeon RX 6000 Series (Metal)", 17_179_869_184, 15_032_385_536, 80, 2_300_000)
            } else if report.contains("RX 5") {
                ("AMD Radeon RX 5000 Series (Metal)", 8_589_934_592, 7_516_192_768, 64, 1_900_000)
            } else {
                ("AMD Discrete Graphics (Metal)", 4_294_967_296, 3_758_096_384, 32, 1_500_000)
            };
        devices.push(metal_device(name, total, available, units, clock, 65536));
    }

    if report.contains("Apple M") {
        // More specific model strings must be tested before the bare "M1".
        let (name, total, available, units, clock) = if report.contains("M1 Advanced") {
            ("Apple M1 Advanced GPU (Metal)", 137_438_953_472, 120_259_084_288, 64, 1_300_000)
        } else if report.contains("M1 Max") {
            ("Apple M1 Max GPU (Metal)", 68_719_476_736, 60_129_542_144, 32, 1_300_000)
        } else if report.contains("M1 Pro") {
            ("Apple M1 Pro GPU (Metal)", 34_359_738_368, 30_064_771_072, 16, 1_300_000)
        } else if report.contains("M1") {
            ("Apple M1 GPU (Metal)", 17_179_869_184, 15_032_385_536, 8, 1_300_000)
        } else if report.contains("M2") {
            ("Apple M2 GPU (Metal)", 25_769_803_776, 22_548_578_304, 10, 1_400_000)
        } else {
            ("Apple Silicon GPU (Metal)", 8_589_934_592, 7_516_192_768, 8, 1_200_000)
        };
        devices.push(metal_device(name, total, available, units, clock, 32768));
    }

    Ok(devices)
}
843
844pub struct MemoryManager {
846 memory_usage: HashMap<(super::Backend, usize), usize>,
848 memory_limits: HashMap<(super::Backend, usize), usize>,
850}
851
852impl MemoryManager {
853 pub fn new() -> Self {
854 Self {
855 memory_usage: HashMap::new(),
856 memory_limits: HashMap::new(),
857 }
858 }
859
860 pub fn can_allocate(&self, backend: super::Backend, deviceid: usize, size: usize) -> bool {
862 let key = (backend, deviceid);
863 let current_usage = self.memory_usage.get(&key).unwrap_or(&0);
864 let limit = self.memory_limits.get(&key).unwrap_or(&usize::MAX);
865
866 current_usage + size <= *limit
867 }
868
869 pub fn allocate(
871 &mut self,
872 backend: super::Backend,
873 device_id: usize,
874 size: usize,
875 ) -> NdimageResult<()> {
876 let key = (backend, device_id);
877
878 if !self.can_allocate(backend, device_id, size) {
879 return Err(NdimageError::ComputationError(
880 "Insufficient GPU memory for allocation".into(),
881 ));
882 }
883
884 *self.memory_usage.entry(key).or_insert(0) += size;
885 Ok(())
886 }
887
888 pub fn deallocate(&mut self, backend: super::Backend, deviceid: usize, size: usize) {
890 let key = (backend, deviceid);
891
892 if let Some(usage) = self.memory_usage.get_mut(&key) {
893 *usage = usage.saturating_sub(size);
894 }
895 }
896
897 pub fn set_memory_limit(&mut self, backend: super::Backend, deviceid: usize, limit: usize) {
899 self.memory_limits.insert((backend, deviceid), limit);
900 }
901
902 pub fn get_memory_usage(&self, backend: super::Backend, deviceid: usize) -> usize {
904 let key = (backend, deviceid);
905 *self.memory_usage.get(&key).unwrap_or(&0)
906 }
907}
908
/// Unit tests for `DeviceCapability` defaults and `MemoryManager` book-keeping.
#[cfg(test)]
mod tests {
    use super::super::Backend;
    use super::*;

    #[test]
    fn test_device_capability_default() {
        let DeviceCapability {
            name, total_memory, ..
        } = DeviceCapability::default();
        assert_eq!(name, "Unknown Device");
        assert_eq!(total_memory, 0);
    }

    #[test]
    fn test_memory_manager() {
        let cpu = Backend::Cpu;
        let mut manager = MemoryManager::new();

        manager.allocate(cpu, 0, 1000).expect("Operation failed");
        assert_eq!(manager.get_memory_usage(cpu, 0), 1000);

        // Releasing part of an allocation lowers the recorded usage.
        manager.deallocate(cpu, 0, 500);
        assert_eq!(manager.get_memory_usage(cpu, 0), 500);

        // With 500 bytes in use under a 2000-byte cap, 1000 more fits but 2000 more does not.
        manager.set_memory_limit(cpu, 0, 2000);
        assert!(manager.can_allocate(cpu, 0, 1000));
        assert!(!manager.can_allocate(cpu, 0, 2000));
    }
}