1use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use crate::error::{NdimageError, NdimageResult};
11
/// Description of a single compute device as seen by the detection routines.
///
/// Every `Option` field is `None` when the corresponding property is unknown
/// for the device.
#[derive(Debug, Clone)]
pub struct DeviceCapability {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory in bytes.
    pub total_memory: usize,
    /// Memory currently considered free, in bytes.
    pub available_memory: usize,
    /// `(major, minor)` compute capability, when the backend has that notion.
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum number of threads in one block / work-group.
    pub max_threads_per_block: Option<usize>,
    /// Maximum block extent along each of the three dimensions.
    pub max_block_dims: Option<[usize; 3]>,
    /// Maximum grid extent along each of the three dimensions.
    pub max_grid_dims: Option<[usize; 3]>,
    /// Shared (local) memory available to one block, presumably in bytes.
    pub shared_memory_per_block: Option<usize>,
    /// Number of multiprocessors / compute units.
    pub multiprocessor_count: Option<usize>,
    /// Clock rate. NOTE(review): the built-in estimates use values such as
    /// `1_500_000`, which suggests kHz - confirm the unit.
    pub clock_rate: Option<usize>,
    /// Memory bandwidth. NOTE(review): the built-in estimates use values such
    /// as `448.0`, which suggests GB/s - confirm the unit.
    pub memory_bandwidth: Option<f64>,
}

impl Default for DeviceCapability {
    /// Returns a placeholder named "Unknown Device" with zero memory and all
    /// optional properties unset.
    fn default() -> Self {
        DeviceCapability {
            name: String::from("Unknown Device"),
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: None,
            max_block_dims: None,
            max_grid_dims: None,
            shared_memory_per_block: None,
            multiprocessor_count: None,
            clock_rate: None,
            memory_bandwidth: None,
        }
    }
}
56
/// Summary of the accelerator backends and resources found on this system,
/// as produced by `DeviceManager::get_capabilities`.
#[derive(Debug, Clone)]
pub struct SystemCapabilities {
    /// At least one CUDA device was detected.
    pub cuda_available: bool,
    /// At least one OpenCL device was detected.
    pub opencl_available: bool,
    /// At least one Metal device was detected.
    pub metal_available: bool,
    /// True when any of the three GPU backends has a device.
    pub gpu_available: bool,
    /// Largest total memory of any single detected device, in MiB
    /// (a maximum, not a sum over devices).
    pub gpu_memory_mb: usize,
    /// Largest multiprocessor / compute-unit count of any detected device.
    pub compute_units: u32,
}
67
/// Registry of the compute devices detected for each backend.
///
/// Each device list only exists when the matching backend is compiled in, so
/// the struct has no fields at all in a build without GPU features.
pub struct DeviceManager {
    /// Devices found by the CUDA detection routine.
    #[cfg(feature = "cuda")]
    cuda_devices: Vec<DeviceCapability>,
    /// Devices found by the OpenCL detection routine.
    #[cfg(feature = "opencl")]
    opencl_devices: Vec<DeviceCapability>,
    /// Devices found by the Metal detection routine (macOS only).
    #[cfg(all(target_os = "macos", feature = "metal"))]
    metal_devices: Vec<DeviceCapability>,
}
77
78impl DeviceManager {
79 pub fn new() -> NdimageResult<Self> {
81 let mut manager = Self {
82 #[cfg(feature = "cuda")]
83 cuda_devices: Vec::new(),
84 #[cfg(feature = "opencl")]
85 opencl_devices: Vec::new(),
86 #[cfg(all(target_os = "macos", feature = "metal"))]
87 metal_devices: Vec::new(),
88 };
89
90 #[cfg(feature = "cuda")]
92 {
93 manager.cuda_devices = detect_cuda_devices()?;
94 }
95
96 #[cfg(feature = "opencl")]
97 {
98 manager.opencl_devices = detect_opencl_devices()?;
99 }
100
101 #[cfg(all(target_os = "macos", feature = "metal"))]
102 {
103 manager.metal_devices = detect_metal_devices()?;
104 }
105
106 Ok(manager)
107 }
108
109 pub fn get_best_device(&self, requiredmemory: usize) -> Option<(super::Backend, usize)> {
111 let mut best_device = None;
112 let mut best_score = 0.0;
113
114 #[cfg(feature = "cuda")]
115 {
116 for (idx, device) in self.cuda_devices.iter().enumerate() {
117 if device.available_memory >= requiredmemory {
118 let score = self.calculate_device_score(device);
119 if score > best_score {
120 best_score = score;
121 best_device = Some((super::Backend::Cuda, idx));
122 }
123 }
124 }
125 }
126
127 #[cfg(feature = "opencl")]
128 {
129 for (idx, device) in self.opencl_devices.iter().enumerate() {
130 if device.available_memory >= requiredmemory {
131 let score = self.calculate_device_score(device) * 0.9; if score > best_score {
133 best_score = score;
134 best_device = Some((super::Backend::OpenCL, idx));
135 }
136 }
137 }
138 }
139
140 #[cfg(all(target_os = "macos", feature = "metal"))]
141 {
142 for (idx, device) in self.metal_devices.iter().enumerate() {
143 if device.available_memory >= requiredmemory {
144 let score = self.calculate_device_score(device) * 0.8; if score > best_score {
146 best_score = score;
147 best_device = Some((super::Backend::Metal, idx));
148 }
149 }
150 }
151 }
152
153 best_device
154 }
155
156 fn calculate_device_score(&self, device: &DeviceCapability) -> f64 {
158 let mut score = 0.0;
159
160 score += (device.total_memory as f64) / (1024.0 * 1024.0 * 1024.0) * 10.0;
162
163 if let Some(mp_count) = device.multiprocessor_count {
165 score += (mp_count as f64) * 5.0;
166 }
167
168 if let Some(clock) = device.clock_rate {
170 score += (clock as f64) / 1_000_000.0 * 3.0;
171 }
172
173 if let Some(bandwidth) = device.memory_bandwidth {
175 score += bandwidth * 0.1;
176 }
177
178 score
179 }
180
181 pub fn get_device_info(
183 &self,
184 backend: super::Backend,
185 device_id: usize,
186 ) -> Option<&DeviceCapability> {
187 match backend {
188 #[cfg(feature = "cuda")]
189 super::Backend::Cuda => self.cuda_devices.get(device_id),
190 #[cfg(feature = "opencl")]
191 super::Backend::OpenCL => self.opencl_devices.get(device_id),
192 #[cfg(all(target_os = "macos", feature = "metal"))]
193 super::Backend::Metal => self.metal_devices.get(device_id),
194 _ => None,
195 }
196 }
197
198 pub fn is_backend_available(&self, backend: super::Backend) -> bool {
200 match backend {
201 #[cfg(feature = "cuda")]
202 super::Backend::Cuda => !self.cuda_devices.is_empty(),
203 #[cfg(feature = "opencl")]
204 super::Backend::OpenCL => !self.opencl_devices.is_empty(),
205 #[cfg(all(target_os = "macos", feature = "metal"))]
206 super::Backend::Metal => !self.metal_devices.is_empty(),
207 super::Backend::Cpu => true,
208 super::Backend::Auto => {
209 #[cfg(feature = "cuda")]
210 if !self.cuda_devices.is_empty() {
211 return true;
212 }
213 #[cfg(feature = "opencl")]
214 if !self.opencl_devices.is_empty() {
215 return true;
216 }
217 #[cfg(all(target_os = "macos", feature = "metal"))]
218 if !self.metal_devices.is_empty() {
219 return true;
220 }
221 true }
223 }
224 }
225
226 pub fn device_count(&self, backend: super::Backend) -> usize {
228 match backend {
229 #[cfg(feature = "cuda")]
230 super::Backend::Cuda => self.cuda_devices.len(),
231 #[cfg(feature = "opencl")]
232 super::Backend::OpenCL => self.opencl_devices.len(),
233 #[cfg(all(target_os = "macos", feature = "metal"))]
234 super::Backend::Metal => self.metal_devices.len(),
235 super::Backend::Cpu => 1,
236 super::Backend::Auto => {
237 let mut total = 1; #[cfg(feature = "cuda")]
239 {
240 total += self.cuda_devices.len();
241 }
242 #[cfg(feature = "opencl")]
243 {
244 total += self.opencl_devices.len();
245 }
246 #[cfg(all(target_os = "macos", feature = "metal"))]
247 {
248 total += self.metal_devices.len();
249 }
250 total
251 }
252 }
253 }
254
255 pub fn get_capabilities(&self) -> SystemCapabilities {
257 let cuda_available = {
258 #[cfg(feature = "cuda")]
259 {
260 !self.cuda_devices.is_empty()
261 }
262 #[cfg(not(feature = "cuda"))]
263 {
264 false
265 }
266 };
267
268 let opencl_available = {
269 #[cfg(feature = "opencl")]
270 {
271 !self.opencl_devices.is_empty()
272 }
273 #[cfg(not(feature = "opencl"))]
274 {
275 false
276 }
277 };
278
279 let metal_available = {
280 #[cfg(all(target_os = "macos", feature = "metal"))]
281 {
282 !self.metal_devices.is_empty()
283 }
284 #[cfg(not(all(target_os = "macos", feature = "metal")))]
285 {
286 false
287 }
288 };
289
290 let gpu_available = cuda_available || opencl_available || metal_available;
291
292 let mut total_memory_mb = 0;
294 let mut max_compute_units = 0;
295
296 #[cfg(feature = "cuda")]
297 {
298 for device in &self.cuda_devices {
299 total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
300 if let Some(mp_count) = device.multiprocessor_count {
301 max_compute_units = max_compute_units.max(mp_count as u32);
302 }
303 }
304 }
305
306 #[cfg(feature = "opencl")]
307 {
308 for device in &self.opencl_devices {
309 total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
310 if let Some(mp_count) = device.multiprocessor_count {
311 max_compute_units = max_compute_units.max(mp_count as u32);
312 }
313 }
314 }
315
316 #[cfg(all(target_os = "macos", feature = "metal"))]
317 {
318 for device in &self.metal_devices {
319 total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
320 if let Some(mp_count) = device.multiprocessor_count {
321 max_compute_units = max_compute_units.max(mp_count as u32);
322 }
323 }
324 }
325
326 SystemCapabilities {
327 cuda_available,
328 opencl_available,
329 metal_available,
330 gpu_available,
331 gpu_memory_mb: total_memory_mb,
332 compute_units: max_compute_units,
333 }
334 }
335}
336
337static DEVICE_MANAGER: OnceLock<Arc<Mutex<DeviceManager>>> = OnceLock::new();
339
340#[allow(dead_code)]
342pub fn get_device_manager() -> NdimageResult<Arc<Mutex<DeviceManager>>> {
343 let result = DEVICE_MANAGER.get_or_init(|| {
344 match DeviceManager::new() {
345 Ok(manager) => Arc::new(Mutex::new(manager)),
346 Err(_) => {
347 Arc::new(Mutex::new(DeviceManager {
349 #[cfg(feature = "cuda")]
350 cuda_devices: Vec::new(),
351 #[cfg(feature = "opencl")]
352 opencl_devices: Vec::new(),
353 #[cfg(all(target_os = "macos", feature = "metal"))]
354 metal_devices: Vec::new(),
355 }))
356 }
357 }
358 });
359 Ok(result.clone())
360}
361
/// Detects CUDA devices by querying `nvidia-smi`.
///
/// Detection is skipped (empty list) unless a CUDA runtime library exists at
/// one of two well-known Linux paths or the `CUDA_PATH` environment variable
/// is set. Each CSV line of `nvidia-smi` output with at least three fields
/// (name, total MiB, free MiB) becomes one device; memory figures that fail
/// to parse are recorded as 0. Compute capability, multiprocessor count,
/// clock rate and bandwidth are estimated from the device name, and the
/// block/grid/shared-memory limits are fixed values, not queried ones.
///
/// A missing or failing `nvidia-smi` yields an empty list, so this function
/// currently always returns `Ok`.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn detect_cuda_devices() -> NdimageResult<Vec<DeviceCapability>> {
    use std::path::Path;
    use std::process::Command;

    let runtime_installed = [
        "/usr/local/cuda/lib64/libcudart.so",
        "/usr/lib/x86_64-linux-gnu/libcudart.so",
    ]
    .iter()
    .any(|candidate| Path::new(candidate).exists())
        || std::env::var("CUDA_PATH").is_ok();

    if !runtime_installed {
        return Ok(Vec::new());
    }

    let query = Command::new("nvidia-smi")
        .arg("--query-gpu=name,memory.total,memory.free")
        .arg("--format=csv,noheader,nounits")
        .output();

    // Treat "tool missing" and "tool failed" alike: no devices.
    let report = match query {
        Ok(finished) if finished.status.success() => finished.stdout,
        _ => return Ok(Vec::new()),
    };
    let listing = String::from_utf8_lossy(&report);

    let devices: Vec<DeviceCapability> = listing
        .lines()
        .enumerate()
        .filter_map(|(index, row)| {
            let fields: Vec<&str> = row.split(',').map(str::trim).collect();
            let [model, total_mib, free_mib, ..] = fields.as_slice() else {
                return None;
            };

            let (compute_capability, multiprocessor_count, clock_rate) =
                estimate_gpu_capabilities(model);

            Some(DeviceCapability {
                name: format!("{} (CUDA Device {})", model, index),
                // nvidia-smi reports MiB with `nounits`; convert to bytes.
                total_memory: total_mib.parse::<usize>().unwrap_or(0) * 1024 * 1024,
                available_memory: free_mib.parse::<usize>().unwrap_or(0) * 1024 * 1024,
                compute_capability,
                max_threads_per_block: Some(1024),
                max_block_dims: Some([1024, 1024, 64]),
                max_grid_dims: Some([65535, 65535, 65535]),
                shared_memory_per_block: Some(49152),
                multiprocessor_count,
                clock_rate,
                memory_bandwidth: estimate_memory_bandwidth(model),
            })
        })
        .collect();

    Ok(devices)
}
434
435#[cfg(feature = "cuda")]
436#[allow(dead_code)]
/// Guesses `(compute capability, multiprocessor count, clock rate)` from a
/// GPU product name.
///
/// The name is matched case-insensitively against a list of product-family
/// keywords, checked in order; the first family that matches wins. Names that
/// match no family get a conservative default. All three components are
/// always `Some`.
fn estimate_gpu_capabilities(name: &str) -> (Option<(u32, u32)>, Option<usize>, Option<usize>) {
    // (keyword, alternate keyword, compute capability, multiprocessors, clock)
    const FAMILIES: [(&str, &str, (u32, u32), usize, usize); 5] = [
        ("rtx 40", "ada lovelace", (8, 9), 128, 2_500_000),
        ("rtx 30", "ampere", (8, 6), 104, 1_700_000),
        ("rtx 20", "turing", (7, 5), 72, 1_500_000),
        ("gtx 16", "gtx 10", (6, 1), 20, 1_400_000),
        ("tesla", "quadro", (7, 0), 80, 1_300_000),
    ];
    const FALLBACK: ((u32, u32), usize, usize) = ((6, 0), 32, 1_000_000);

    let lowered = name.to_lowercase();
    let (capability, multiprocessors, clock) = FAMILIES
        .iter()
        .find(|&&(keyword, alternate, ..)| lowered.contains(keyword) || lowered.contains(alternate))
        .map_or(FALLBACK, |&(_, _, capability, multiprocessors, clock)| {
            (capability, multiprocessors, clock)
        });

    (Some(capability), Some(multiprocessors), Some(clock))
}
461
462#[cfg(feature = "cuda")]
463#[allow(dead_code)]
/// Guesses a GPU's memory bandwidth from its product name.
///
/// The name is matched case-insensitively against a table of known models,
/// checked in order; unknown models get a fallback of 320.0. The result is
/// always `Some`. NOTE(review): the figures look like GB/s - confirm the unit.
fn estimate_memory_bandwidth(name: &str) -> Option<f64> {
    const KNOWN_MODELS: [(&str, f64); 8] = [
        ("rtx 4090", 1008.0),
        ("rtx 4080", 717.0),
        ("rtx 3090", 936.0),
        ("rtx 3080", 760.0),
        ("rtx 3070", 448.0),
        ("rtx 2080", 448.0),
        ("tesla v100", 900.0),
        ("tesla a100", 1555.0),
    ];
    const FALLBACK: f64 = 320.0;

    let lowered = name.to_lowercase();
    let bandwidth = KNOWN_MODELS
        .iter()
        .find(|(model, _)| lowered.contains(*model))
        .map_or(FALLBACK, |&(_, value)| value);

    Some(bandwidth)
}
487
/// Detects OpenCL devices by parsing the output of `clinfo --list`.
///
/// Detection is skipped (empty list) unless an OpenCL library exists at one
/// of two well-known Linux paths or the `OPENCL_ROOT` environment variable is
/// set. Every `clinfo` line that mentions "Device" but not "Platform" becomes
/// one device whose memory, compute units, clock and bandwidth are estimated
/// from its name; available memory is assumed to be 80% of the total.
///
/// When `clinfo` yields nothing, generic fallback entries are added: an Intel
/// integrated GPU if `/sys/class/drm/card0` exists, and an AMD discrete GPU
/// if `HSA_ENABLE_SDMA` is set or `/opt/rocm` exists.
///
/// This function currently always returns `Ok`.
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn detect_opencl_devices() -> NdimageResult<Vec<DeviceCapability>> {
    let opencl_available = std::path::Path::new("/usr/lib/x86_64-linux-gnu/libOpenCL.so.1")
        .exists()
        || std::path::Path::new("/usr/local/lib/libOpenCL.so").exists()
        || std::env::var("OPENCL_ROOT").is_ok();

    if !opencl_available {
        return Ok(Vec::new());
    }

    let mut devices = Vec::new();

    if let Ok(output) = std::process::Command::new("clinfo").arg("--list").output() {
        if output.status.success() {
            let output_str = String::from_utf8_lossy(&output.stdout);

            // Filter BEFORE enumerating so the number in the label is the
            // device's ordinal (and therefore its index in `devices`), not
            // the line number within the clinfo output, which also counts
            // platform lines.
            let device_lines = output_str
                .lines()
                .filter(|line| line.contains("Device") && !line.contains("Platform"));

            for (i, line) in device_lines.enumerate() {
                let device_name = line
                    .split("Device")
                    .nth(1)
                    .unwrap_or("Unknown OpenCL Device")
                    .trim()
                    .to_string();

                let (memory_size, compute_units, clock_freq) =
                    estimate_opencl_capabilities(&device_name);

                devices.push(DeviceCapability {
                    name: format!("{} (OpenCL Device {})", device_name, i),
                    total_memory: memory_size,
                    // No free-memory query is made; assume 80% is usable.
                    available_memory: (memory_size as f64 * 0.8) as usize,
                    compute_capability: None,
                    max_threads_per_block: Some(1024),
                    max_block_dims: Some([1024, 1024, 1024]),
                    max_grid_dims: None,
                    shared_memory_per_block: Some(32768),
                    multiprocessor_count: Some(compute_units),
                    clock_rate: Some(clock_freq),
                    memory_bandwidth: estimate_opencl_bandwidth(&device_name),
                });
            }
        }
    }

    if devices.is_empty() {
        // clinfo gave us nothing: fall back to coarse guesses based on what
        // the system exposes.
        if std::path::Path::new("/sys/class/drm/card0").exists() {
            devices.push(DeviceCapability {
                name: "Intel Integrated Graphics (OpenCL)".to_string(),
                total_memory: 2_147_483_648,
                available_memory: 1_717_986_918,
                compute_capability: None,
                max_threads_per_block: Some(512),
                max_block_dims: Some([512, 512, 512]),
                max_grid_dims: None,
                shared_memory_per_block: Some(32768),
                multiprocessor_count: Some(24),
                clock_rate: Some(1_000_000),
                memory_bandwidth: Some(25.6),
            });
        }

        if std::env::var("HSA_ENABLE_SDMA").is_ok() || std::path::Path::new("/opt/rocm").exists() {
            devices.push(DeviceCapability {
                name: "AMD Discrete Graphics (OpenCL)".to_string(),
                total_memory: 8_589_934_592,
                available_memory: 6_871_947_674,
                compute_capability: None,
                max_threads_per_block: Some(1024),
                max_block_dims: Some([1024, 1024, 1024]),
                max_grid_dims: None,
                shared_memory_per_block: Some(65536),
                multiprocessor_count: Some(64),
                clock_rate: Some(1_500_000),
                memory_bandwidth: Some(448.0),
            });
        }
    }

    Ok(devices)
}
583
584#[cfg(feature = "opencl")]
585#[allow(dead_code)]
/// Guesses `(total memory in bytes, compute units, clock rate)` for an
/// OpenCL device from its name.
///
/// The name is matched case-insensitively: first by vendor (Intel, then
/// AMD/Radeon, then NVIDIA/GeForce/Quadro), then by product line within the
/// vendor. Unrecognised vendors get the smallest estimate.
fn estimate_opencl_capabilities(name: &str) -> (usize, usize, usize) {
    let lowered = name.to_lowercase();
    let has = |needle: &str| lowered.contains(needle);

    if has("intel") {
        return if has("iris") || has("xe") {
            (4_294_967_296, 96, 1_300_000)
        } else {
            (2_147_483_648, 24, 1_000_000)
        };
    }

    if has("amd") || has("radeon") {
        return if has("rx 7") || has("rx 6") {
            (16_106_127_360, 80, 2_000_000)
        } else if has("rx 5") {
            (8_589_934_592, 64, 1_800_000)
        } else {
            (4_294_967_296, 36, 1_500_000)
        };
    }

    if has("nvidia") || has("geforce") || has("quadro") {
        return if has("rtx") {
            (12_884_901_888, 84, 1_700_000)
        } else {
            (8_589_934_592, 56, 1_500_000)
        };
    }

    (2_147_483_648, 16, 1_000_000)
}
620
621#[cfg(feature = "opencl")]
622#[allow(dead_code)]
/// Guesses an OpenCL device's memory bandwidth from its name.
///
/// The name is matched case-insensitively against vendor / product keywords
/// in a fixed order; the first match wins and unknown devices get 100.0. The
/// result is always `Some`. NOTE(review): the figures look like GB/s -
/// confirm the unit.
fn estimate_opencl_bandwidth(name: &str) -> Option<f64> {
    let lowered = name.to_lowercase();
    let has = |needle: &str| lowered.contains(needle);

    let bandwidth = match () {
        _ if has("intel iris") || has("intel xe") => 68.0,
        _ if has("intel") => 25.6,
        _ if has("rx 7") => 960.0,
        _ if has("rx 6") => 512.0,
        _ if has("rx 5") => 448.0,
        _ if has("nvidia") => 760.0,
        _ => 100.0,
    };

    Some(bandwidth)
}
642
/// Detects Metal-capable GPUs on macOS.
///
/// Combines the result of the integrated-GPU probe with the list from the
/// discrete-GPU probe, integrated first. A probe that fails is skipped rather
/// than reported, so this function currently always returns `Ok`.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_metal_devices() -> NdimageResult<Vec<DeviceCapability>> {
    let mut devices = Vec::new();

    // Detection is best effort: an error from either probe just means that
    // probe contributes no devices.
    if let Ok(gpu_info) = detect_macos_integrated_gpu() {
        devices.push(gpu_info);
    }

    if let Ok(discrete_gpus) = detect_macos_discrete_gpus() {
        devices.extend(discrete_gpus);
    }

    Ok(devices)
}
672
/// Describes the integrated GPU of a macOS machine.
///
/// Runs `system_profiler SPDisplaysDataType -xml` and classifies the machine
/// by searching the output for "Intel" first and then "AMD"; anything else is
/// reported as an unknown integrated GPU. All figures are fixed estimates per
/// class, not values read from the hardware.
///
/// # Errors
///
/// Returns `NdimageError::ComputationError` when `system_profiler` cannot be
/// started or exits unsuccessfully.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_macos_integrated_gpu() -> NdimageResult<DeviceCapability> {
    use std::process::Command;

    let output = Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-xml")
        .output()
        .map_err(|e| {
            NdimageError::ComputationError(format!("Failed to run system_profiler: {}", e))
        })?;

    if !output.status.success() {
        return Err(NdimageError::ComputationError(
            "system_profiler failed".into(),
        ));
    }

    let output_str = String::from_utf8_lossy(&output.stdout);

    // Fields not assigned below keep their `Default` value (`None`).
    let mut capability = DeviceCapability::default();

    if output_str.contains("Intel") {
        capability.name = "Intel Integrated Graphics (Metal)".to_string();
        capability.total_memory = 1_073_741_824;
        capability.available_memory = 805_306_368;
        capability.multiprocessor_count = Some(16);
        capability.clock_rate = Some(1_000_000);
        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 64]);
        capability.shared_memory_per_block = Some(32768);
    } else if output_str.contains("AMD") {
        capability.name = "AMD Integrated Graphics (Metal)".to_string();
        capability.total_memory = 2_147_483_648;
        capability.available_memory = 1_610_612_736;
        capability.multiprocessor_count = Some(32);
        capability.clock_rate = Some(1_200_000);
        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 64]);
        capability.shared_memory_per_block = Some(65536);
    } else {
        capability.name = "Unknown Integrated Graphics (Metal)".to_string();
        capability.total_memory = 1_073_741_824;
        capability.available_memory = 805_306_368;
        capability.multiprocessor_count = Some(8);
        capability.clock_rate = Some(800_000);
        capability.max_threads_per_block = Some(512);
        capability.max_block_dims = Some([512, 512, 64]);
        capability.shared_memory_per_block = Some(16384);
    }

    Ok(capability)
}
729
/// Describes the discrete AMD GPU and/or Apple Silicon GPU of a macOS machine.
///
/// Runs `system_profiler SPDisplaysDataType -xml` and searches its output for
/// known product names. At most one AMD entry (when "Radeon" or "RX " is
/// present) and one Apple entry (when "Apple M" is present) are produced. All
/// figures are fixed estimates per model, not values read from the hardware.
///
/// # Errors
///
/// Returns `NdimageError::ComputationError` when `system_profiler` cannot be
/// started. An unsuccessful exit status yields an empty list instead.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_macos_discrete_gpus() -> NdimageResult<Vec<DeviceCapability>> {
    use std::process::Command;

    let mut devices = Vec::new();

    let output = Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-xml")
        .output()
        .map_err(|e| {
            NdimageError::ComputationError(format!("Failed to run system_profiler: {}", e))
        })?;

    if !output.status.success() {
        return Ok(devices);
    }

    let output_str = String::from_utf8_lossy(&output.stdout);

    if output_str.contains("Radeon") || output_str.contains("RX ") {
        let mut capability = DeviceCapability::default();

        if output_str.contains("RX 6800") || output_str.contains("RX 6900") {
            capability.name = "AMD Radeon RX 6000 Series (Metal)".to_string();
            capability.total_memory = 17_179_869_184;
            capability.available_memory = 15_032_385_536;
            capability.multiprocessor_count = Some(80);
            capability.clock_rate = Some(2_300_000);
        } else if output_str.contains("RX 5") {
            capability.name = "AMD Radeon RX 5000 Series (Metal)".to_string();
            capability.total_memory = 8_589_934_592;
            capability.available_memory = 7_516_192_768;
            capability.multiprocessor_count = Some(64);
            capability.clock_rate = Some(1_900_000);
        } else {
            capability.name = "AMD Discrete Graphics (Metal)".to_string();
            capability.total_memory = 4_294_967_296;
            capability.available_memory = 3_758_096_384;
            capability.multiprocessor_count = Some(32);
            capability.clock_rate = Some(1_500_000);
        }

        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 1024]);
        capability.shared_memory_per_block = Some(65536);
        devices.push(capability);
    }

    if output_str.contains("Apple M") {
        let mut capability = DeviceCapability::default();

        // The more specific M1 variants must be tested before plain "M1",
        // which is a substring of all of them.
        if output_str.contains("M1 Ultra") {
            // The top M1 part is marketed as "M1 Ultra" (64 GPU cores, up to
            // 128 GiB unified memory); the previous "M1 Advanced" spelling
            // names no Apple chip, so this branch could never match.
            capability.name = "Apple M1 Ultra GPU (Metal)".to_string();
            capability.total_memory = 137_438_953_472;
            capability.available_memory = 120_259_084_288;
            capability.multiprocessor_count = Some(64);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M1 Max") {
            capability.name = "Apple M1 Max GPU (Metal)".to_string();
            capability.total_memory = 68_719_476_736;
            capability.available_memory = 60_129_542_144;
            capability.multiprocessor_count = Some(32);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M1 Pro") {
            capability.name = "Apple M1 Pro GPU (Metal)".to_string();
            capability.total_memory = 34_359_738_368;
            capability.available_memory = 30_064_771_072;
            capability.multiprocessor_count = Some(16);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M1") {
            capability.name = "Apple M1 GPU (Metal)".to_string();
            capability.total_memory = 17_179_869_184;
            capability.available_memory = 15_032_385_536;
            capability.multiprocessor_count = Some(8);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M2") {
            capability.name = "Apple M2 GPU (Metal)".to_string();
            capability.total_memory = 25_769_803_776;
            capability.available_memory = 22_548_578_304;
            capability.multiprocessor_count = Some(10);
            capability.clock_rate = Some(1_400_000);
        } else {
            capability.name = "Apple Silicon GPU (Metal)".to_string();
            capability.total_memory = 8_589_934_592;
            capability.available_memory = 7_516_192_768;
            capability.multiprocessor_count = Some(8);
            capability.clock_rate = Some(1_200_000);
        }

        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 1024]);
        capability.shared_memory_per_block = Some(32768);
        devices.push(capability);
    }

    Ok(devices)
}
834
835pub struct MemoryManager {
837 memory_usage: HashMap<(super::Backend, usize), usize>,
839 memory_limits: HashMap<(super::Backend, usize), usize>,
841}
842
843impl MemoryManager {
844 pub fn new() -> Self {
845 Self {
846 memory_usage: HashMap::new(),
847 memory_limits: HashMap::new(),
848 }
849 }
850
851 pub fn can_allocate(&self, backend: super::Backend, deviceid: usize, size: usize) -> bool {
853 let key = (backend, deviceid);
854 let current_usage = self.memory_usage.get(&key).unwrap_or(&0);
855 let limit = self.memory_limits.get(&key).unwrap_or(&usize::MAX);
856
857 current_usage + size <= *limit
858 }
859
860 pub fn allocate(
862 &mut self,
863 backend: super::Backend,
864 device_id: usize,
865 size: usize,
866 ) -> NdimageResult<()> {
867 let key = (backend, device_id);
868
869 if !self.can_allocate(backend, device_id, size) {
870 return Err(NdimageError::ComputationError(
871 "Insufficient GPU memory for allocation".into(),
872 ));
873 }
874
875 *self.memory_usage.entry(key).or_insert(0) += size;
876 Ok(())
877 }
878
879 pub fn deallocate(&mut self, backend: super::Backend, deviceid: usize, size: usize) {
881 let key = (backend, deviceid);
882
883 if let Some(usage) = self.memory_usage.get_mut(&key) {
884 *usage = usage.saturating_sub(size);
885 }
886 }
887
888 pub fn set_memory_limit(&mut self, backend: super::Backend, deviceid: usize, limit: usize) {
890 self.memory_limits.insert((backend, deviceid), limit);
891 }
892
893 pub fn get_memory_usage(&self, backend: super::Backend, deviceid: usize) -> usize {
895 let key = (backend, deviceid);
896 *self.memory_usage.get(&key).unwrap_or(&0)
897 }
898}
899
#[cfg(test)]
mod tests {
    use super::*;

    /// The default capability is a named placeholder with no memory.
    #[test]
    fn test_device_capability_default() {
        let DeviceCapability {
            name, total_memory, ..
        } = DeviceCapability::default();

        assert_eq!(name, "Unknown Device");
        assert_eq!(total_memory, 0);
    }

    /// Allocation, release and limit checks update the recorded usage.
    #[test]
    fn test_memory_manager() {
        let cpu = super::super::Backend::Cpu;
        let mut tracker = MemoryManager::new();

        // Record an allocation and read it back.
        tracker.allocate(cpu, 0, 1000).expect("Operation failed");
        assert_eq!(tracker.get_memory_usage(cpu, 0), 1000);

        // Releasing half leaves the other half recorded.
        tracker.deallocate(cpu, 0, 500);
        assert_eq!(tracker.get_memory_usage(cpu, 0), 500);

        // With a limit of 2000 and 500 in use, 1000 fits but 2000 does not.
        tracker.set_memory_limit(cpu, 0, 2000);
        assert!(tracker.can_allocate(cpu, 0, 1000));
        assert!(!tracker.can_allocate(cpu, 0, 2000));
    }
}