rust_gpu_tools/
device.rs

1use std::fmt;
2
3use log::debug;
4#[cfg(all(feature = "opencl", feature = "cuda"))]
5use log::warn;
6use once_cell::sync::Lazy;
7
8use std::convert::TryFrom;
9use std::mem;
10
11use crate::error::{GPUError, GPUResult};
12
13#[cfg(feature = "cuda")]
14use crate::cuda;
15#[cfg(feature = "opencl")]
16use crate::opencl;
17
/// The UUID of the devices returned by OpenCL as well as CUDA are always 16 bytes long.
const UUID_SIZE: usize = 16;
// Vendor name and numeric vendor ID that are mapped to `Vendor::Amd`.
const AMD_DEVICE_VENDOR_STRING: &str = "Advanced Micro Devices, Inc.";
const AMD_DEVICE_VENDOR_ID: u32 = 0x1002;
// For some reason integrated AMD cards on Apple don't have the usual vendor name and ID. These
// values are mapped to `Vendor::Amd` as well.
const AMD_DEVICE_ON_APPLE_VENDOR_STRING: &str = "AMD";
const AMD_DEVICE_ON_APPLE_VENDOR_ID: u32 = 0x1021d00;
// Vendor name and numeric vendor ID that are mapped to `Vendor::Nvidia`.
const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation";
const NVIDIA_DEVICE_VENDOR_ID: u32 = 0x10de;
27
// The owned CUDA contexts are stored globally. Each device contains an unowned reference, so
// that devices can be cloned.
// The list is built lazily by `build_device_list` on first access.
#[cfg(feature = "cuda")]
static DEVICES: Lazy<(Vec<Device>, cuda::utils::CudaContexts)> = Lazy::new(build_device_list);

// Keep it as a tuple, as in the CUDA case, so that using `DEVICES` is independent of the
// enabled features.
#[cfg(all(feature = "opencl", not(feature = "cuda")))]
static DEVICES: Lazy<(Vec<Device>, ())> = Lazy::new(build_device_list);
37
/// The PCI-ID is the combination of the PCI Bus ID and PCI Device ID.
///
/// These are the first two identifiers printed by e.g. `lspci`:
///
/// ```text
///     4e:00.0 VGA compatible controller
///     || └└-- Device ID
///     └└-- Bus ID
/// ```
///
/// Both identifiers are packed into a single `u16`; the string conversions of this type treat
/// the high byte as the Bus ID and the low byte as the Device ID.
#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct PciId(u16);

/// Wraps an already packed 16-bit value without any validation.
impl From<u16> for PciId {
    fn from(raw: u16) -> Self {
        PciId(raw)
    }
}

/// Unwraps the packed 16-bit value.
impl From<PciId> for u16 {
    fn from(PciId(raw): PciId) -> Self {
        raw
    }
}
61
62/// Converts a PCI-ID formatted as Bus-ID:Device-ID, e.g. `e3:00`.
63impl TryFrom<&str> for PciId {
64    type Error = GPUError;
65
66    fn try_from(pci_id: &str) -> GPUResult<Self> {
67        let mut bytes = [0; mem::size_of::<u16>()];
68        hex::decode_to_slice(pci_id.replace(':', ""), &mut bytes).map_err(|_| {
69            GPUError::InvalidId(format!(
70                "Cannot parse PCI ID, expected hex-encoded string formated as aa:bb, got {0}.",
71                pci_id
72            ))
73        })?;
74        let parsed = u16::from_be_bytes(bytes);
75        Ok(Self(parsed))
76    }
77}
78
79/// Formats the PCI-ID like `lspci`, Bus-ID:Device-ID, e.g. `e3:00`.
80impl fmt::Display for PciId {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        let bytes = u16::to_be_bytes(self.0);
83        write!(f, "{:02x}:{:02x}", bytes[0], bytes[1])
84    }
85}
86
87/// A unique identifier based on UUID of the device.
88#[derive(Copy, Clone, Default, Eq, Hash, PartialEq)]
89pub struct DeviceUuid([u8; UUID_SIZE]);
90
91impl From<[u8; UUID_SIZE]> for DeviceUuid {
92    fn from(uuid: [u8; UUID_SIZE]) -> Self {
93        Self(uuid)
94    }
95}
96
97impl From<DeviceUuid> for [u8; UUID_SIZE] {
98    fn from(uuid: DeviceUuid) -> Self {
99        uuid.0
100    }
101}
102
103/// Converts a UUID formatted as aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee,
104/// e.g. 46abccd6-022e-b783-572d-833f7104d05f
105impl TryFrom<&str> for DeviceUuid {
106    type Error = GPUError;
107
108    fn try_from(uuid: &str) -> GPUResult<Self> {
109        let mut bytes = [0; UUID_SIZE];
110        hex::decode_to_slice(uuid.replace('-', ""), &mut bytes)
111            .map_err(|_| {
112                GPUError::InvalidId(format!("Cannot parse UUID, expected hex-encoded string formated as aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee, got {0}.", uuid))
113            })?;
114        Ok(Self(bytes))
115    }
116}
117
118/// Formats the UUID the same way as `clinfo` does, as an example:
119/// the output should looks like 46abccd6-022e-b783-572d-833f7104d05f
120impl fmt::Display for DeviceUuid {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        write!(
123            f,
124            "{}-{}-{}-{}-{}",
125            hex::encode(&self.0[..4]),
126            hex::encode(&self.0[4..6]),
127            hex::encode(&self.0[6..8]),
128            hex::encode(&self.0[8..10]),
129            hex::encode(&self.0[10..])
130        )
131    }
132}
133
134impl fmt::Debug for DeviceUuid {
135    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136        write!(f, "{}", self)
137    }
138}
139
/// Unique identifier that can either be a PCI ID or a UUID.
///
/// It can be parsed from a string (a string containing a dash is treated as a UUID) and is
/// displayed in the format of the wrapped identifier.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum UniqueId {
    /// ID based on the PCI bus.
    PciId(PciId),
    /// ID based on a globally unique identifier.
    Uuid(DeviceUuid),
}
148
149/// If the string contains a dash, it's interpreted as UUID, else it's interpreted as PCI ID.
150impl TryFrom<&str> for UniqueId {
151    type Error = GPUError;
152
153    fn try_from(unique_id: &str) -> GPUResult<Self> {
154        Ok(match unique_id.contains('-') {
155            true => Self::Uuid(DeviceUuid::try_from(unique_id)?),
156            false => Self::PciId(PciId::try_from(unique_id)?),
157        })
158    }
159}
160
161impl fmt::Display for UniqueId {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        match self {
164            Self::PciId(id) => id.fmt(f),
165            Self::Uuid(id) => id.fmt(f),
166        }
167    }
168}
169
/// Currently supported vendors of this library.
///
/// A vendor can be created from a vendor name (`&str`) or a numeric vendor ID (`u32`); unknown
/// values are rejected with `GPUError::UnsupportedVendor`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Vendor {
    /// GPU by AMD.
    Amd,
    /// GPU by NVIDIA.
    Nvidia,
}
178
179impl TryFrom<&str> for Vendor {
180    type Error = GPUError;
181
182    fn try_from(vendor: &str) -> GPUResult<Self> {
183        match vendor {
184            AMD_DEVICE_VENDOR_STRING => Ok(Self::Amd),
185            AMD_DEVICE_ON_APPLE_VENDOR_STRING => Ok(Self::Amd),
186            NVIDIA_DEVICE_VENDOR_STRING => Ok(Self::Nvidia),
187            _ => Err(GPUError::UnsupportedVendor(vendor.to_string())),
188        }
189    }
190}
191
192impl TryFrom<u32> for Vendor {
193    type Error = GPUError;
194
195    fn try_from(vendor: u32) -> GPUResult<Self> {
196        match vendor {
197            AMD_DEVICE_VENDOR_ID => Ok(Self::Amd),
198            AMD_DEVICE_ON_APPLE_VENDOR_ID => Ok(Self::Amd),
199            NVIDIA_DEVICE_VENDOR_ID => Ok(Self::Nvidia),
200            _ => Err(GPUError::UnsupportedVendor(format!("0x{:x}", vendor))),
201        }
202    }
203}
204
205impl fmt::Display for Vendor {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        let vendor = match self {
208            Self::Amd => AMD_DEVICE_VENDOR_STRING,
209            Self::Nvidia => NVIDIA_DEVICE_VENDOR_STRING,
210        };
211        write!(f, "{}", vendor)
212    }
213}
214
/// Which framework to use, CUDA or OpenCL.
///
/// Each variant only exists if the corresponding Cargo feature is enabled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Framework {
    /// CUDA.
    #[cfg(feature = "cuda")]
    Cuda,
    /// OpenCL.
    #[cfg(feature = "opencl")]
    Opencl,
}
225
/// A device that may have a CUDA and/or OpenCL GPU associated with it.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Device {
    // Vendor of the GPU.
    vendor: Vendor,
    // Name of the GPU.
    name: String,
    // Memory of the GPU in bytes.
    memory: u64,
    // Number of compute units of the GPU.
    compute_units: u32,
    /// Major and minor version of the compute capability (only available on Nvidia GPUs).
    compute_capability: Option<(u32, u32)>,
    // All devices have a PCI ID. It is used as fallback in case there is no UUID.
    pci_id: PciId,
    // Preferred over the PCI ID as unique identifier when present.
    uuid: Option<DeviceUuid>,
    // The underlying CUDA device, if there is one for this GPU.
    #[cfg(feature = "cuda")]
    cuda: Option<cuda::Device>,
    // The underlying OpenCL device, if there is one for this GPU.
    #[cfg(feature = "opencl")]
    opencl: Option<opencl::Device>,
}
243
244impl Device {
245    /// Returns the [`Vendor`] of the GPU.
246    pub fn vendor(&self) -> Vendor {
247        self.vendor
248    }
249
250    /// Returns the name of the GPU, e.g. "GeForce RTX 3090".
251    pub fn name(&self) -> String {
252        self.name.clone()
253    }
254
255    /// Returns the memory of the GPU in bytes.
256    pub fn memory(&self) -> u64 {
257        self.memory
258    }
259
260    /// Returns the number of compute units of the GPU.
261    pub fn compute_units(&self) -> u32 {
262        self.compute_units
263    }
264
265    /// Returns the major and minor version of the compute capability (only available on Nvidia
266    /// GPUs).
267    pub fn compute_capability(&self) -> Option<(u32, u32)> {
268        self.compute_capability
269    }
270
271    /// Returns the best possible unique identifier, a UUID is preferred over a PCI ID.
272    pub fn unique_id(&self) -> UniqueId {
273        match self.uuid {
274            Some(uuid) => UniqueId::Uuid(uuid),
275            None => UniqueId::PciId(self.pci_id),
276        }
277    }
278
279    /// Returns the preferred framework (CUDA or OpenCL) to use.
280    ///
281    /// CUDA will be be preferred over OpenCL. The returned framework will work on the device.
282    /// E.g. it won't return `Framework::Cuda` for an AMD device.
283    pub fn framework(&self) -> Framework {
284        #[cfg(all(feature = "opencl", feature = "cuda"))]
285        if cfg!(feature = "cuda") && self.cuda.is_some() {
286            Framework::Cuda
287        } else {
288            Framework::Opencl
289        }
290
291        #[cfg(all(feature = "cuda", not(feature = "opencl")))]
292        {
293            Framework::Cuda
294        }
295
296        #[cfg(all(feature = "opencl", not(feature = "cuda")))]
297        {
298            Framework::Opencl
299        }
300    }
301
302    /// Returns the underlying CUDA device if it is available.
303    #[cfg(feature = "cuda")]
304    pub fn cuda_device(&self) -> Option<&cuda::Device> {
305        self.cuda.as_ref()
306    }
307
308    /// Returns the underlying OpenCL device if it is available.
309    #[cfg(feature = "opencl")]
310    pub fn opencl_device(&self) -> Option<&opencl::Device> {
311        self.opencl.as_ref()
312    }
313
314    /// Returns all available GPUs that are supported.
315    pub fn all() -> Vec<&'static Device> {
316        Self::all_iter().collect()
317    }
318
319    /// Returns the device matching the PCI ID if there is one.
320    pub fn by_pci_id(pci_id: PciId) -> Option<&'static Device> {
321        Self::all_iter().find(|d| pci_id == d.pci_id)
322    }
323
324    /// Returns the device matching the UUID if there is one.
325    pub fn by_uuid(uuid: DeviceUuid) -> Option<&'static Device> {
326        Self::all_iter().find(|d| Some(uuid) == d.uuid)
327    }
328
329    /// Returns the device matching the unique ID if there is one.
330    pub fn by_unique_id(unique_id: UniqueId) -> Option<&'static Device> {
331        Self::all_iter().find(|d| unique_id == d.unique_id())
332    }
333
334    /// Returns an iterator of all available GPUs that are supported.
335    fn all_iter() -> impl Iterator<Item = &'static Device> {
336        DEVICES.0.iter()
337    }
338}
339
/// Get a list of all available and supported devices.
///
/// If both, the `cuda` and the `opencl` feature are enabled, a device supporting both will be
/// combined into a single device. You can then access the underlying CUDA and OpenCL device
/// if needed.
///
/// If there is a failure retrieving a device, it won't lead to a hard error, but an error will be
/// logged and the corresponding device won't be available.
///
/// Returns the combined device list together with the owned CUDA contexts.
#[cfg(feature = "cuda")]
fn build_device_list() -> (Vec<Device>, cuda::utils::CudaContexts) {
    let mut all_devices = Vec::new();

    #[cfg(feature = "opencl")]
    let opencl_devices = opencl::utils::build_device_list();

    // The CUDA list only has to be mutable when OpenCL is enabled as well, as matched devices
    // are then moved out of it.
    #[cfg(all(feature = "cuda", feature = "opencl"))]
    let (mut cuda_devices, cuda_contexts) = cuda::utils::build_device_list();
    #[cfg(all(feature = "cuda", not(feature = "opencl")))]
    let (cuda_devices, cuda_contexts) = cuda::utils::build_device_list();

    // Combine OpenCL and CUDA devices into one device if it is the same GPU
    #[cfg(feature = "opencl")]
    for opencl_device in opencl_devices {
        let mut device = Device {
            vendor: opencl_device.vendor(),
            name: opencl_device.name(),
            memory: opencl_device.memory(),
            compute_units: opencl_device.compute_units(),
            compute_capability: opencl_device.compute_capability(),
            pci_id: opencl_device.pci_id(),
            uuid: opencl_device.uuid(),
            opencl: Some(opencl_device),
            cuda: None,
        };

        // Only devices from Nvidia can use CUDA
        if device.vendor == Vendor::Nvidia {
            // Only one device can match, hence only the first CUDA device with the same UUID
            // (if the OpenCL device has one) or the same PCI ID is considered.
            let first_match = cuda_devices.iter().position(|candidate| {
                (device.uuid.is_some() && candidate.uuid() == device.uuid)
                    || (candidate.pci_id() == device.pci_id)
            });

            if let Some(index) = first_match {
                // On a mismatch the CUDA device stays in the list and is added as a separate
                // device further below.
                if device.memory() != cuda_devices[index].memory() {
                    warn!("OpenCL and CUDA report different amounts of memory for a device with the same identifier");
                } else if device.compute_units() != cuda_devices[index].compute_units() {
                    warn!("OpenCL and CUDA report different amounts of compute units for a device with the same identifier");
                } else {
                    // Move the CUDA device out of the vector
                    device.cuda = Some(cuda_devices.remove(index));
                }
            }
        }

        all_devices.push(device);
    }

    // All CUDA devices that don't have a corresponding OpenCL devices
    for cuda_device in cuda_devices {
        let device = Device {
            vendor: cuda_device.vendor(),
            name: cuda_device.name(),
            memory: cuda_device.memory(),
            compute_units: cuda_device.compute_units(),
            compute_capability: Some(cuda_device.compute_capability()),
            pci_id: cuda_device.pci_id(),
            uuid: cuda_device.uuid(),
            cuda: Some(cuda_device),
            #[cfg(feature = "opencl")]
            opencl: None,
        };
        all_devices.push(device);
    }

    debug!("loaded devices: {:?}", all_devices);
    (all_devices, cuda_contexts)
}
421
/// Get a list of all available and supported OpenCL devices.
///
/// If there is a failure retrieving a device, it won't lead to a hard error, but an error will be
/// logged and the corresponding device won't be available.
///
/// The second tuple element is a unit placeholder for the slot that holds the CUDA contexts
/// when the `cuda` feature is enabled.
#[cfg(all(feature = "opencl", not(feature = "cuda")))]
fn build_device_list() -> (Vec<Device>, ()) {
    let mut devices = Vec::new();
    for opencl_device in opencl::utils::build_device_list() {
        devices.push(Device {
            vendor: opencl_device.vendor(),
            name: opencl_device.name(),
            memory: opencl_device.memory(),
            compute_units: opencl_device.compute_units(),
            compute_capability: opencl_device.compute_capability(),
            pci_id: opencl_device.pci_id(),
            uuid: opencl_device.uuid(),
            // Moved last, after all properties have been read from it.
            opencl: Some(opencl_device),
        });
    }

    debug!("loaded devices: {:?}", devices);
    (devices, ())
}
445
// Unit tests for the identifier and vendor conversions, plus a device listing smoke test.
#[cfg(test)]
mod test {
    use super::{
        Device, DeviceUuid, GPUError, PciId, UniqueId, Vendor, AMD_DEVICE_ON_APPLE_VENDOR_ID,
        AMD_DEVICE_ON_APPLE_VENDOR_STRING, AMD_DEVICE_VENDOR_ID, AMD_DEVICE_VENDOR_STRING,
        NVIDIA_DEVICE_VENDOR_ID, NVIDIA_DEVICE_VENDOR_STRING,
    };
    use std::convert::TryFrom;

    // Lists all devices; fails if no supported GPU is found.
    #[test]
    fn test_device_all() {
        let devices = Device::all();
        for device in devices.iter() {
            println!("device: {:?}", device);
        }
        assert!(!devices.is_empty(), "No supported GPU found.");
    }

    // Every known vendor name maps to its vendor, unknown names are rejected.
    #[test]
    fn test_vendor_from_str() {
        assert_eq!(
            Vendor::try_from(AMD_DEVICE_VENDOR_STRING).unwrap(),
            Vendor::Amd,
            "AMD vendor string can be converted."
        );
        assert_eq!(
            Vendor::try_from(AMD_DEVICE_ON_APPLE_VENDOR_STRING).unwrap(),
            Vendor::Amd,
            "AMD vendor string (on apple) can be converted."
        );
        assert_eq!(
            Vendor::try_from(NVIDIA_DEVICE_VENDOR_STRING).unwrap(),
            Vendor::Nvidia,
            "Nvidia vendor string can be converted."
        );
        assert!(matches!(
            Vendor::try_from("unknown vendor"),
            Err(GPUError::UnsupportedVendor(_))
        ));
    }

    // Every known vendor ID maps to its vendor, unknown IDs are rejected.
    #[test]
    fn test_vendor_from_u32() {
        assert_eq!(
            Vendor::try_from(AMD_DEVICE_VENDOR_ID).unwrap(),
            Vendor::Amd,
            "AMD vendor ID can be converted."
        );
        assert_eq!(
            Vendor::try_from(AMD_DEVICE_ON_APPLE_VENDOR_ID).unwrap(),
            Vendor::Amd,
            "AMD vendor ID (on apple) can be converted."
        );
        assert_eq!(
            Vendor::try_from(NVIDIA_DEVICE_VENDOR_ID).unwrap(),
            Vendor::Nvidia,
            "Nvidia vendor ID can be converted."
        );
        assert!(matches!(
            Vendor::try_from(0x1abc),
            Err(GPUError::UnsupportedVendor(_))
        ));
    }

    // Vendors are displayed as their regular vendor name.
    #[test]
    fn test_vendor_display() {
        assert_eq!(
            Vendor::Amd.to_string(),
            AMD_DEVICE_VENDOR_STRING,
            "AMD vendor can be converted to string."
        );
        assert_eq!(
            Vendor::Nvidia.to_string(),
            NVIDIA_DEVICE_VENDOR_STRING,
            "Nvidia vendor can be converted to string."
        );
    }

    // A valid UUID round-trips through parsing and formatting; malformed input is rejected.
    #[test]
    fn test_uuid() {
        let valid_string = "46abccd6-022e-b783-572d-833f7104d05f";
        let valid = DeviceUuid::try_from(valid_string).unwrap();
        assert_eq!(valid_string, &valid.to_string());

        // The first group is missing four hex digits.
        let too_short_string = "ccd6-022e-b783-572d-833f7104d05f";
        let too_short = DeviceUuid::try_from(too_short_string);
        assert!(too_short.is_err(), "Parse error when UUID is too short.");

        // The trailing `h` is not a hex digit.
        let invalid_hex_string = "46abccd6-022e-b783-572d-833f7104d05h";
        let invalid_hex = DeviceUuid::try_from(invalid_hex_string);
        assert!(
            invalid_hex.is_err(),
            "Parse error when UUID containts non-hex character."
        );
    }

    // A valid PCI ID round-trips through parsing and formatting; malformed input is rejected.
    #[test]
    fn test_pci_id() {
        let valid_string = "01:00";
        let valid = PciId::try_from(valid_string).unwrap();
        assert_eq!(valid_string, &valid.to_string());
        // The Bus-ID is stored in the high byte.
        assert_eq!(valid, PciId(0x0100));

        let too_short_string = "3f";
        let too_short = PciId::try_from(too_short_string);
        assert!(too_short.is_err(), "Parse error when PCI ID is too short.");

        let invalid_hex_string = "aaxx";
        let invalid_hex = PciId::try_from(invalid_hex_string);
        assert!(
            invalid_hex.is_err(),
            "Parse error when PCI ID containts non-hex character."
        );
    }

    // Strings without a dash are parsed as PCI ID, strings with a dash as UUID.
    #[test]
    fn test_unique_id() {
        let valid_pci_id_string = "aa:bb";
        let valid_pci_id = UniqueId::try_from(valid_pci_id_string).unwrap();
        assert_eq!(valid_pci_id_string, &valid_pci_id.to_string());
        assert_eq!(valid_pci_id, UniqueId::PciId(PciId(0xaabb)));

        let valid_uuid_string = "aabbccdd-eeff-0011-2233-445566778899";
        let valid_uuid = UniqueId::try_from(valid_uuid_string).unwrap();
        assert_eq!(valid_uuid_string, &valid_uuid.to_string());
        assert_eq!(
            valid_uuid,
            UniqueId::Uuid(DeviceUuid([
                0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
                0x88, 0x99
            ]))
        );

        // No dash, so this is parsed as a PCI ID, for which it is too long and not valid hex.
        let invalid_string = "aabbccddeeffgg";
        let invalid = UniqueId::try_from(invalid_string);
        assert!(
            invalid.is_err(),
            "Parse error when ID matches neither a PCI Id, nor a UUID."
        );
    }
}