1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use std::convert::TryFrom;

use log::{debug, warn};

use crate::cuda::Device;
use crate::device::{PciId, Vendor};
use crate::error::{GPUError, GPUResult};

// NOTE vmx 2021-04-14: This is a hack to make sure contexts stay around. We wrap them, so that
// `Sync` can be implemented. `Sync` is needed for lazy static. These contexts are never used
// directly, they are only accessed through [`cuda::Device`] which contains an `UnownedContext`.
// A device cannot have an own context itself, as then it couldn't be cloned, but that is needed
// for creating the kernels.
pub(crate) struct CudaContexts(Vec<rustacuda::context::Context>);
unsafe impl Sync for CudaContexts {}
unsafe impl Send for CudaContexts {}

/// The PCI-ID is the combination of the PCI Bus ID and PCI Device ID.
///
/// It is the first two identifiers of e.g. `lspci`:
///
/// ```text
///     4e:00.0 VGA compatible controller
///     || └└-- Device ID
///     └└-- Bus ID
/// ```
fn get_pci_id(device: &rustacuda::device::Device) -> Result<PciId, GPUError> {
    let bus_id = device.get_attribute(rustacuda::device::DeviceAttribute::PciBusId)? as u16;
    let device_id = device.get_attribute(rustacuda::device::DeviceAttribute::PciDeviceId)? as u16;
    let pci_id = (bus_id << 8) | device_id;
    Ok(pci_id.into())
}

fn get_memory(d: &rustacuda::device::Device) -> GPUResult<u64> {
    let memory = d.total_memory()?;
    Ok(u64::try_from(memory).expect("Platform must be <= 64-bit"))
}

fn get_compute_units(d: &rustacuda::device::Device) -> GPUResult<u32> {
    let compute_units = d.get_attribute(rustacuda::device::DeviceAttribute::MultiprocessorCount)?;
    Ok(u32::try_from(compute_units).expect("The number of units is always positive"))
}

/// Get the major an minor version of the compute capability.
fn get_compute_capability(d: &rustacuda::device::Device) -> GPUResult<(u32, u32)> {
    let major = d.get_attribute(rustacuda::device::DeviceAttribute::ComputeCapabilityMajor)?;
    let minor = d.get_attribute(rustacuda::device::DeviceAttribute::ComputeCapabilityMinor)?;
    Ok((
        u32::try_from(major).expect("The compute capability major version is always positive"),
        u32::try_from(minor).expect("The compute capability minor version is always positive"),
    ))
}

/// Get a list of all available and supported devices.
///
/// If there is a failure initializing CUDA or retrieving a device, it won't lead to a hard error,
/// but an error will be logged and the corresponding device won't be available.
pub(crate) fn build_device_list() -> (Vec<Device>, CudaContexts) {
    let mut all_devices = Vec::new();
    let mut devices_without_pci_id = Vec::new();
    let mut contexts = Vec::new();

    rustacuda::init(rustacuda::CudaFlags::empty())
        .map_err(Into::into)
        .and_then(|_| {
            for device in rustacuda::device::Device::devices()? {
                let device = device?;
                let owned_context = rustacuda::context::Context::create_and_push(
                    rustacuda::context::ContextFlags::MAP_HOST
                        | rustacuda::context::ContextFlags::SCHED_AUTO,
                    device,
                )?;
                rustacuda::context::ContextStack::pop()?;

                let vendor = Vendor::Nvidia;
                let name = device.name()?;
                let memory = get_memory(&device)?;
                let compute_units = get_compute_units(&device)?;
                let compute_capability = get_compute_capability(&device)?;
                let uuid = device.uuid().ok().map(Into::into);
                let context = owned_context.get_unowned();

                contexts.push(owned_context);

                // If a device doesn't have a PCI-ID, add those later to the list of
                // devices with a fake PCI-ID.
                match get_pci_id(&device) {
                    Ok(pci_id) => {
                        all_devices.push(Device {
                            vendor,
                            name,
                            memory,
                            compute_units,
                            compute_capability,
                            pci_id,
                            uuid,
                            device,
                            context,
                        });
                    }
                    Err(_) => {
                        // Use a temporary PCI-ID and replace it later with a non-colliding one.
                        let pci_id = PciId::from(0);
                        devices_without_pci_id.push(Device {
                            vendor,
                            name,
                            memory,
                            compute_units,
                            compute_capability,
                            pci_id,
                            uuid,
                            device,
                            context,
                        });
                    }
                };
            }

            // Laptops might have an integrated GPU. Such devices might have neither a PCI-ID, nor a UUID.
            // As those devices are used for development and not for production use, it's good enough to
            // provide a workaround which doesn't add much complexity to the code. We use a fake PCI-ID
            // instead, which is generated by enumerating the available devices. In order to make that
            // case easier to spot when debugging issues, a starting number which is pleasant to the human
            // eye was choosen, that works both, decimal and hexadecimal (4660 == 0x1234).
            let mut enumerated_device: u16 = 4660;
            for mut device in devices_without_pci_id.into_iter() {
                // Make sure that no device has that actual PCI-ID
                while all_devices
                    .iter()
                    .any(|d| d.pci_id() == enumerated_device.into())
                {
                    enumerated_device += 1;
                }
                device.pci_id = PciId::from(enumerated_device);
                enumerated_device += 1;
                all_devices.push(device);
            }

            let wrapped_contexts = CudaContexts(contexts);

            debug!("Loaded CUDA devices: {:?}", all_devices);
            Ok((all_devices, wrapped_contexts))
        })
        .unwrap_or_else(|error: GPUError| {
            warn!("Unable to retrieve CUDA devices: {:?}", error);
            (Vec::new(), CudaContexts(Vec::new()))
        })
}