1use std::fmt;
2
3use log::debug;
4#[cfg(all(feature = "opencl", feature = "cuda"))]
5use log::warn;
6use once_cell::sync::Lazy;
7
8use std::convert::TryFrom;
9use std::mem;
10
11use crate::error::{GPUError, GPUResult};
12
13#[cfg(feature = "cuda")]
14use crate::cuda;
15#[cfg(feature = "opencl")]
16use crate::opencl;
17
/// Number of bytes in a [`DeviceUuid`].
const UUID_SIZE: usize = 16;

/// Vendor name that `Vendor::try_from(&str)` maps to `Vendor::Amd`; also what `Vendor::Amd`
/// displays as.
const AMD_DEVICE_VENDOR_STRING: &str = "Advanced Micro Devices, Inc.";
/// Numeric vendor ID that `Vendor::try_from(u32)` maps to `Vendor::Amd`.
const AMD_DEVICE_VENDOR_ID: u32 = 0x1002;
/// Alternate vendor name that is also mapped to `Vendor::Amd`.
// NOTE(review): the name suggests this is what AMD devices report on Apple platforms — confirm.
const AMD_DEVICE_ON_APPLE_VENDOR_STRING: &str = "AMD";
/// Alternate numeric vendor ID that is also mapped to `Vendor::Amd`.
// NOTE(review): the name suggests this is what AMD devices report on Apple platforms — confirm.
const AMD_DEVICE_ON_APPLE_VENDOR_ID: u32 = 0x1021d00;
/// Vendor name that `Vendor::try_from(&str)` maps to `Vendor::Nvidia`; also what
/// `Vendor::Nvidia` displays as.
const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation";
/// Numeric vendor ID that `Vendor::try_from(u32)` maps to `Vendor::Nvidia`.
const NVIDIA_DEVICE_VENDOR_ID: u32 = 0x10de;
27
/// Global device list, built by `build_device_list` on first access.
///
/// With the `cuda` feature the CUDA contexts returned by `build_device_list` are stored next to
/// the devices.
// NOTE(review): presumably kept here so the contexts stay alive as long as the devices — confirm.
#[cfg(feature = "cuda")]
static DEVICES: Lazy<(Vec<Device>, cuda::utils::CudaContexts)> = Lazy::new(build_device_list);

/// Global device list for OpenCL-only builds, built by `build_device_list` on first access.
/// There are no CUDA contexts, so the second tuple element is `()`.
#[cfg(all(feature = "opencl", not(feature = "cuda")))]
static DEVICES: Lazy<(Vec<Device>, ())> = Lazy::new(build_device_list);
37
/// PCI identifier of a device, stored as a single 16-bit value.
///
/// Its `Display` and `TryFrom<&str>` forms are two hex-encoded bytes joined by a colon,
/// e.g. `01:00`.
#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct PciId(u16);

impl From<u16> for PciId {
    /// Wraps a raw 16-bit PCI ID.
    fn from(id: u16) -> Self {
        PciId(id)
    }
}

impl From<PciId> for u16 {
    /// Unwraps the raw 16-bit value.
    fn from(id: PciId) -> Self {
        let PciId(raw) = id;
        raw
    }
}
61
62impl TryFrom<&str> for PciId {
64 type Error = GPUError;
65
66 fn try_from(pci_id: &str) -> GPUResult<Self> {
67 let mut bytes = [0; mem::size_of::<u16>()];
68 hex::decode_to_slice(pci_id.replace(':', ""), &mut bytes).map_err(|_| {
69 GPUError::InvalidId(format!(
70 "Cannot parse PCI ID, expected hex-encoded string formated as aa:bb, got {0}.",
71 pci_id
72 ))
73 })?;
74 let parsed = u16::from_be_bytes(bytes);
75 Ok(Self(parsed))
76 }
77}
78
79impl fmt::Display for PciId {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 let bytes = u16::to_be_bytes(self.0);
83 write!(f, "{:02x}:{:02x}", bytes[0], bytes[1])
84 }
85}
86
87#[derive(Copy, Clone, Default, Eq, Hash, PartialEq)]
89pub struct DeviceUuid([u8; UUID_SIZE]);
90
91impl From<[u8; UUID_SIZE]> for DeviceUuid {
92 fn from(uuid: [u8; UUID_SIZE]) -> Self {
93 Self(uuid)
94 }
95}
96
97impl From<DeviceUuid> for [u8; UUID_SIZE] {
98 fn from(uuid: DeviceUuid) -> Self {
99 uuid.0
100 }
101}
102
103impl TryFrom<&str> for DeviceUuid {
106 type Error = GPUError;
107
108 fn try_from(uuid: &str) -> GPUResult<Self> {
109 let mut bytes = [0; UUID_SIZE];
110 hex::decode_to_slice(uuid.replace('-', ""), &mut bytes)
111 .map_err(|_| {
112 GPUError::InvalidId(format!("Cannot parse UUID, expected hex-encoded string formated as aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee, got {0}.", uuid))
113 })?;
114 Ok(Self(bytes))
115 }
116}
117
118impl fmt::Display for DeviceUuid {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 write!(
123 f,
124 "{}-{}-{}-{}-{}",
125 hex::encode(&self.0[..4]),
126 hex::encode(&self.0[4..6]),
127 hex::encode(&self.0[6..8]),
128 hex::encode(&self.0[8..10]),
129 hex::encode(&self.0[10..])
130 )
131 }
132}
133
134impl fmt::Debug for DeviceUuid {
135 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136 write!(f, "{}", self)
137 }
138}
139
140#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
142pub enum UniqueId {
143 PciId(PciId),
145 Uuid(DeviceUuid),
147}
148
149impl TryFrom<&str> for UniqueId {
151 type Error = GPUError;
152
153 fn try_from(unique_id: &str) -> GPUResult<Self> {
154 Ok(match unique_id.contains('-') {
155 true => Self::Uuid(DeviceUuid::try_from(unique_id)?),
156 false => Self::PciId(PciId::try_from(unique_id)?),
157 })
158 }
159}
160
161impl fmt::Display for UniqueId {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 match self {
164 Self::PciId(id) => id.fmt(f),
165 Self::Uuid(id) => id.fmt(f),
166 }
167 }
168}
169
170#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
172pub enum Vendor {
173 Amd,
175 Nvidia,
177}
178
179impl TryFrom<&str> for Vendor {
180 type Error = GPUError;
181
182 fn try_from(vendor: &str) -> GPUResult<Self> {
183 match vendor {
184 AMD_DEVICE_VENDOR_STRING => Ok(Self::Amd),
185 AMD_DEVICE_ON_APPLE_VENDOR_STRING => Ok(Self::Amd),
186 NVIDIA_DEVICE_VENDOR_STRING => Ok(Self::Nvidia),
187 _ => Err(GPUError::UnsupportedVendor(vendor.to_string())),
188 }
189 }
190}
191
192impl TryFrom<u32> for Vendor {
193 type Error = GPUError;
194
195 fn try_from(vendor: u32) -> GPUResult<Self> {
196 match vendor {
197 AMD_DEVICE_VENDOR_ID => Ok(Self::Amd),
198 AMD_DEVICE_ON_APPLE_VENDOR_ID => Ok(Self::Amd),
199 NVIDIA_DEVICE_VENDOR_ID => Ok(Self::Nvidia),
200 _ => Err(GPUError::UnsupportedVendor(format!("0x{:x}", vendor))),
201 }
202 }
203}
204
205impl fmt::Display for Vendor {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 let vendor = match self {
208 Self::Amd => AMD_DEVICE_VENDOR_STRING,
209 Self::Nvidia => NVIDIA_DEVICE_VENDOR_STRING,
210 };
211 write!(f, "{}", vendor)
212 }
213}
214
/// Compute framework through which a [`Device`] is used.
///
/// Each variant only exists when the corresponding cargo feature is enabled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Framework {
    /// CUDA (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
    /// OpenCL (requires the `opencl` feature).
    #[cfg(feature = "opencl")]
    Opencl,
}
225
/// A GPU device, discovered through OpenCL, CUDA, or both.
///
/// Instances are only created by `build_device_list`. They are reached through [`Device::all`]
/// and the `by_*` lookup functions.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Device {
    vendor: Vendor,
    name: String,
    /// Memory amount as reported by the backend.
    // NOTE(review): units are not established in this file; presumably bytes — confirm.
    memory: u64,
    compute_units: u32,
    /// Always `Some` for devices found only via CUDA. For OpenCL-discovered devices it is
    /// whatever the OpenCL backend reports.
    // NOTE(review): presumably `(major, minor)` — confirm.
    compute_capability: Option<(u32, u32)>,
    /// Always present; used as the unique ID when `uuid` is `None`.
    pci_id: PciId,
    /// Preferred unique ID when available.
    uuid: Option<DeviceUuid>,
    /// Set when the device was found via CUDA, either directly or matched to an OpenCL device.
    #[cfg(feature = "cuda")]
    cuda: Option<cuda::Device>,
    /// Set when the device was found via OpenCL.
    #[cfg(feature = "opencl")]
    opencl: Option<opencl::Device>,
}
243
244impl Device {
245 pub fn vendor(&self) -> Vendor {
247 self.vendor
248 }
249
250 pub fn name(&self) -> String {
252 self.name.clone()
253 }
254
255 pub fn memory(&self) -> u64 {
257 self.memory
258 }
259
260 pub fn compute_units(&self) -> u32 {
262 self.compute_units
263 }
264
265 pub fn compute_capability(&self) -> Option<(u32, u32)> {
268 self.compute_capability
269 }
270
271 pub fn unique_id(&self) -> UniqueId {
273 match self.uuid {
274 Some(uuid) => UniqueId::Uuid(uuid),
275 None => UniqueId::PciId(self.pci_id),
276 }
277 }
278
279 pub fn framework(&self) -> Framework {
284 #[cfg(all(feature = "opencl", feature = "cuda"))]
285 if cfg!(feature = "cuda") && self.cuda.is_some() {
286 Framework::Cuda
287 } else {
288 Framework::Opencl
289 }
290
291 #[cfg(all(feature = "cuda", not(feature = "opencl")))]
292 {
293 Framework::Cuda
294 }
295
296 #[cfg(all(feature = "opencl", not(feature = "cuda")))]
297 {
298 Framework::Opencl
299 }
300 }
301
302 #[cfg(feature = "cuda")]
304 pub fn cuda_device(&self) -> Option<&cuda::Device> {
305 self.cuda.as_ref()
306 }
307
308 #[cfg(feature = "opencl")]
310 pub fn opencl_device(&self) -> Option<&opencl::Device> {
311 self.opencl.as_ref()
312 }
313
314 pub fn all() -> Vec<&'static Device> {
316 Self::all_iter().collect()
317 }
318
319 pub fn by_pci_id(pci_id: PciId) -> Option<&'static Device> {
321 Self::all_iter().find(|d| pci_id == d.pci_id)
322 }
323
324 pub fn by_uuid(uuid: DeviceUuid) -> Option<&'static Device> {
326 Self::all_iter().find(|d| Some(uuid) == d.uuid)
327 }
328
329 pub fn by_unique_id(unique_id: UniqueId) -> Option<&'static Device> {
331 Self::all_iter().find(|d| unique_id == d.unique_id())
332 }
333
334 fn all_iter() -> impl Iterator<Item = &'static Device> {
336 DEVICES.0.iter()
337 }
338}
339
/// Discovers all devices, merging entries seen by both OpenCL and CUDA.
///
/// With the `opencl` feature, every OpenCL device is added first. For an NVIDIA OpenCL device,
/// the first remaining CUDA device matching it is looked up. A CUDA device matches when its UUID
/// equals the OpenCL device's UUID (only if the OpenCL device has one), or when their PCI IDs are
/// equal.
///
/// The matched CUDA device is attached only if both report the same memory and compute units.
/// On a mismatch a warning is logged and the CUDA device is left unattached. Every CUDA device
/// that was not attached is then added as its own entry.
///
/// Returns the devices together with the CUDA contexts from `cuda::utils::build_device_list`.
#[cfg(feature = "cuda")]
fn build_device_list() -> (Vec<Device>, cuda::utils::CudaContexts) {
    let mut all_devices = Vec::new();

    #[cfg(feature = "opencl")]
    let opencl_devices = opencl::utils::build_device_list();

    // The CUDA list is only mutated (via `remove`) when OpenCL devices can claim entries from it.
    #[cfg(feature = "opencl")]
    let (mut cuda_devices, cuda_contexts) = cuda::utils::build_device_list();
    #[cfg(not(feature = "opencl"))]
    let (cuda_devices, cuda_contexts) = cuda::utils::build_device_list();

    #[cfg(feature = "opencl")]
    for opencl_device in opencl_devices {
        let mut device = Device {
            vendor: opencl_device.vendor(),
            name: opencl_device.name(),
            memory: opencl_device.memory(),
            compute_units: opencl_device.compute_units(),
            compute_capability: opencl_device.compute_capability(),
            pci_id: opencl_device.pci_id(),
            uuid: opencl_device.uuid(),
            opencl: Some(opencl_device),
            cuda: None,
        };

        // Only NVIDIA devices are looked up among the CUDA devices.
        if device.vendor == Vendor::Nvidia {
            let matching = cuda_devices.iter().position(|cuda_device| {
                (device.uuid.is_some() && cuda_device.uuid() == device.uuid)
                    || cuda_device.pci_id() == device.pci_id
            });
            if let Some(index) = matching {
                let candidate = &cuda_devices[index];
                if device.memory() != candidate.memory() {
                    warn!("OpenCL and CUDA report different amounts of memory for a device with the same identifier");
                } else if device.compute_units() != candidate.compute_units() {
                    warn!("OpenCL and CUDA report different amounts of compute units for a device with the same identifier");
                } else {
                    // Found through both frameworks: the OpenCL entry takes ownership of it.
                    device.cuda = Some(cuda_devices.remove(index));
                }
            }
        }

        all_devices.push(device);
    }

    // Whatever CUDA devices were not claimed above become standalone entries.
    all_devices.extend(cuda_devices.into_iter().map(|cuda_device| Device {
        vendor: cuda_device.vendor(),
        name: cuda_device.name(),
        memory: cuda_device.memory(),
        compute_units: cuda_device.compute_units(),
        compute_capability: Some(cuda_device.compute_capability()),
        pci_id: cuda_device.pci_id(),
        uuid: cuda_device.uuid(),
        cuda: Some(cuda_device),
        #[cfg(feature = "opencl")]
        opencl: None,
    }));

    debug!("loaded devices: {:?}", all_devices);
    (all_devices, cuda_contexts)
}
421
/// Discovers all OpenCL devices in an OpenCL-only build.
///
/// There are no CUDA contexts in this configuration, so the second tuple element is `()`.
#[cfg(all(feature = "opencl", not(feature = "cuda")))]
fn build_device_list() -> (Vec<Device>, ()) {
    let opencl_devices = opencl::utils::build_device_list();
    let mut devices = Vec::new();
    for opencl_device in opencl_devices {
        devices.push(Device {
            vendor: opencl_device.vendor(),
            name: opencl_device.name(),
            memory: opencl_device.memory(),
            compute_units: opencl_device.compute_units(),
            compute_capability: opencl_device.compute_capability(),
            pci_id: opencl_device.pci_id(),
            uuid: opencl_device.uuid(),
            opencl: Some(opencl_device),
        });
    }

    debug!("loaded devices: {:?}", devices);
    (devices, ())
}
445
#[cfg(test)]
mod test {
    use super::{
        Device, DeviceUuid, GPUError, PciId, UniqueId, Vendor, AMD_DEVICE_ON_APPLE_VENDOR_ID,
        AMD_DEVICE_ON_APPLE_VENDOR_STRING, AMD_DEVICE_VENDOR_ID, AMD_DEVICE_VENDOR_STRING,
        NVIDIA_DEVICE_VENDOR_ID, NVIDIA_DEVICE_VENDOR_STRING,
    };
    use std::convert::TryFrom;

    /// Needs at least one supported GPU on the machine running the tests.
    #[test]
    fn test_device_all() {
        let devices = Device::all();
        devices
            .iter()
            .for_each(|device| println!("device: {:?}", device));
        assert!(!devices.is_empty(), "No supported GPU found.");
    }

    #[test]
    fn test_vendor_from_str() {
        let cases = [
            (
                AMD_DEVICE_VENDOR_STRING,
                Vendor::Amd,
                "AMD vendor string can be converted.",
            ),
            (
                AMD_DEVICE_ON_APPLE_VENDOR_STRING,
                Vendor::Amd,
                "AMD vendor string (on apple) can be converted.",
            ),
            (
                NVIDIA_DEVICE_VENDOR_STRING,
                Vendor::Nvidia,
                "Nvidia vendor string can be converted.",
            ),
        ];
        for &(input, expected, message) in &cases {
            assert_eq!(Vendor::try_from(input).unwrap(), expected, "{}", message);
        }
        assert!(matches!(
            Vendor::try_from("unknown vendor"),
            Err(GPUError::UnsupportedVendor(_))
        ));
    }

    #[test]
    fn test_vendor_from_u32() {
        let cases = [
            (AMD_DEVICE_VENDOR_ID, Vendor::Amd, "AMD vendor ID can be converted."),
            (
                AMD_DEVICE_ON_APPLE_VENDOR_ID,
                Vendor::Amd,
                "AMD vendor ID (on apple) can be converted.",
            ),
            (
                NVIDIA_DEVICE_VENDOR_ID,
                Vendor::Nvidia,
                "Nvidia vendor ID can be converted.",
            ),
        ];
        for &(input, expected, message) in &cases {
            assert_eq!(Vendor::try_from(input).unwrap(), expected, "{}", message);
        }
        assert!(matches!(
            Vendor::try_from(0x1abc),
            Err(GPUError::UnsupportedVendor(_))
        ));
    }

    #[test]
    fn test_vendor_display() {
        let cases = [
            (
                Vendor::Amd,
                AMD_DEVICE_VENDOR_STRING,
                "AMD vendor can be converted to string.",
            ),
            (
                Vendor::Nvidia,
                NVIDIA_DEVICE_VENDOR_STRING,
                "Nvidia vendor can be converted to string.",
            ),
        ];
        for &(vendor, expected, message) in &cases {
            assert_eq!(vendor.to_string(), expected, "{}", message);
        }
    }

    #[test]
    fn test_uuid() {
        let valid_string = "46abccd6-022e-b783-572d-833f7104d05f";
        let valid = DeviceUuid::try_from(valid_string).unwrap();
        assert_eq!(valid_string, &valid.to_string());

        assert!(
            DeviceUuid::try_from("ccd6-022e-b783-572d-833f7104d05f").is_err(),
            "Parse error when UUID is too short."
        );
        assert!(
            DeviceUuid::try_from("46abccd6-022e-b783-572d-833f7104d05h").is_err(),
            "Parse error when UUID containts non-hex character."
        );
    }

    #[test]
    fn test_pci_id() {
        let valid_string = "01:00";
        let valid = PciId::try_from(valid_string).unwrap();
        assert_eq!(valid_string, &valid.to_string());
        assert_eq!(valid, PciId(0x0100));

        assert!(
            PciId::try_from("3f").is_err(),
            "Parse error when PCI ID is too short."
        );
        assert!(
            PciId::try_from("aaxx").is_err(),
            "Parse error when PCI ID containts non-hex character."
        );
    }

    #[test]
    fn test_unique_id() {
        let pci_id_string = "aa:bb";
        let pci_id = UniqueId::try_from(pci_id_string).unwrap();
        assert_eq!(pci_id_string, &pci_id.to_string());
        assert_eq!(pci_id, UniqueId::PciId(PciId(0xaabb)));

        let uuid_string = "aabbccdd-eeff-0011-2233-445566778899";
        let uuid = UniqueId::try_from(uuid_string).unwrap();
        assert_eq!(uuid_string, &uuid.to_string());
        let expected_uuid = DeviceUuid([
            0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
            0x88, 0x99,
        ]);
        assert_eq!(uuid, UniqueId::Uuid(expected_uuid));

        assert!(
            UniqueId::try_from("aabbccddeeffgg").is_err(),
            "Parse error when ID matches neither a PCI Id, nor a UUID."
        );
    }
}