vkfetch_rs/
device.rs

1use ash::Instance;
2use ash::vk;
3use ash::vk::PhysicalDeviceProperties2;
4use ash::vk::PhysicalDeviceShaderCoreProperties2AMD;
5use ash::vk::PhysicalDeviceShaderCorePropertiesAMD;
6use ash::vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV;
7use std::ffi::CStr;
8
9use crate::vendor::Vendor;
10
/// Represents a physical GPU device.
///
/// Every field is populated by `Device::new` from Vulkan physical-device queries.
#[derive(Debug)]
pub struct Device {
    /// Vendor resolved from `vendor_id` through `Vendor::from_vendor_id`.
    pub vendor: Vendor,
    /// Device name reported by the driver, or `"Unknown"` if it could not be read as a C string.
    pub device_name: String,
    /// Device category converted from the raw Vulkan device-type value.
    pub device_type: DeviceType,
    /// Raw device ID from the core physical-device properties.
    pub device_id: u32,
    /// Raw vendor ID from the core physical-device properties.
    pub vendor_id: u32,
    /// Driver name from the driver properties, or `"Unknown"` if it could not be read.
    pub driver_name: String,
    /// Driver info string from the driver properties, or `"Unknown"` if it could not be read.
    pub driver_info: String,
    /// API version formatted by `decode_version_number` as "variant.major.minor.patch".
    pub api_version: String,
    // VRAM: both values refer to the first DEVICE_LOCAL memory heap (heap 0 if none is flagged).
    /// Budget reported for the selected heap by the memory-budget query; stays 0 when the
    /// driver did not fill it in.
    pub heapbudget: u64,
    /// Size reported for the selected heap.
    pub heapsize: u64,
    /// Derived metrics, vendor-specific properties, limits and feature flags.
    pub characteristics: GPUCharacteristics,
}
27
/// Contains various characteristics of a GPU.
///
/// Vendor-specific properties are stored as `Option`s and stay `None` when they do not
/// apply to the device. The struct also carries a few general device limits and
/// feature flags.
///
/// Note on equality: `memory_pressure` may be NaN (no budget information), and NaN never
/// compares equal, so such a value is not equal to itself under `PartialEq`.
#[derive(Debug, Clone, PartialEq)]
pub struct GPUCharacteristics {
    /// Memory pressure as computed from VRAM usage (0.0 to 1.0); NaN when no heap
    /// budget was reported.
    pub memory_pressure: f32,
    // AMD-specific properties.
    /// Number of compute units.
    pub compute_units: Option<u32>,
    /// Number of shader engines.
    pub shader_engines: Option<u32>,
    /// Number of shader arrays in each shader engine.
    pub shader_arrays_per_engine_count: Option<u32>,
    /// Number of compute units in each shader array.
    pub compute_units_per_shader_array: Option<u32>,
    /// Number of SIMD units in each compute unit.
    pub simd_per_compute_unit: Option<u32>,
    /// Number of wavefront slots in each SIMD unit.
    pub wavefronts_per_simd: Option<u32>,
    /// Wavefront size reported by the driver.
    pub wavefront_size: Option<u32>,
    // NVIDIA-specific properties.
    /// Number of streaming multiprocessors.
    pub streaming_multiprocessors: Option<u32>,
    /// Number of warps per streaming multiprocessor.
    pub warps_per_sm: Option<u32>,
    // General device limits.
    /// `maxImageDimension2D` device limit.
    pub max_image_dimension_2d: u32,
    /// `maxComputeSharedMemorySize` device limit.
    pub max_compute_shared_memory_size: u32,
    /// `maxComputeWorkGroupInvocations` device limit.
    pub max_compute_work_group_invocations: u32,
    // Feature flags.
    /// A queue family exists that supports transfer but neither graphics nor compute.
    pub dedicated_transfer_queue: bool,
    /// A queue family exists that supports compute but not graphics.
    pub dedicated_async_compute_queue: bool,
    /// The device advertises a ray tracing extension.
    pub supports_ray_tracing: bool,
}
55
56impl Device {
57    /// Constructs a new `PhysicalDevice` by querying Vulkan properties.
58    pub fn new(instance: &Instance, physical_device: vk::PhysicalDevice) -> Self {
59        // Get the core properties and limits.
60        let physical_device_properties: vk::PhysicalDeviceProperties =
61            unsafe { instance.get_physical_device_properties(physical_device) };
62        let limits = physical_device_properties.limits;
63
64        // Query additional driver properties.
65        let mut driver_properties: vk::PhysicalDeviceDriverProperties =
66            vk::PhysicalDeviceDriverProperties::default();
67        let mut properties2: PhysicalDeviceProperties2 =
68            PhysicalDeviceProperties2::default().push_next(&mut driver_properties);
69        unsafe {
70            instance.get_physical_device_properties2(physical_device, &mut properties2);
71        }
72
73        let vendor_id = physical_device_properties.vendor_id;
74        let vendor = Vendor::from_vendor_id(vendor_id).unwrap_or_else(|| {
75            eprintln!("Unknown vendor: {}", vendor_id);
76            panic!();
77        });
78
79        let device_name = cstring_to_string(
80            physical_device_properties
81                .device_name_as_c_str()
82                .unwrap_or(c"Unknown"),
83        );
84        let device_type = DeviceType::from(physical_device_properties.device_type.as_raw());
85        let device_id = physical_device_properties.device_id;
86        let api_version = decode_version_number(physical_device_properties.api_version);
87        let driver_name = cstring_to_string(
88            driver_properties
89                .driver_name_as_c_str()
90                .unwrap_or(c"Unknown"),
91        );
92        let driver_info = cstring_to_string(
93            driver_properties
94                .driver_info_as_c_str()
95                .unwrap_or(c"Unknown"),
96        );
97
98        // Query VRAM details.
99        let mut memory_budget = vk::PhysicalDeviceMemoryBudgetPropertiesEXT::default();
100        let mut memory_properties2 =
101            vk::PhysicalDeviceMemoryProperties2::default().push_next(&mut memory_budget);
102        unsafe {
103            instance
104                .get_physical_device_memory_properties2(physical_device, &mut memory_properties2);
105        }
106        let memory_properties = memory_properties2.memory_properties;
107        let vram_heap_index = (0..memory_properties.memory_heap_count)
108            .find(|&i| {
109                memory_properties.memory_heaps[i as usize]
110                    .flags
111                    .contains(vk::MemoryHeapFlags::DEVICE_LOCAL)
112            })
113            .unwrap_or(0);
114        let heapsize = memory_properties.memory_heaps[vram_heap_index as usize].size;
115        let heapbudget = memory_budget.heap_budget[vram_heap_index as usize];
116        let memory_pressure = if heapbudget > 0 {
117            (heapsize - heapbudget) as f32 / heapsize as f32
118        } else {
119            f32::NAN
120        };
121
122        // Query queue family properties.
123        let queue_families =
124            unsafe { instance.get_physical_device_queue_family_properties(physical_device) };
125        let mut dedicated_transfer_queue = false;
126        let mut dedicated_async_compute_queue = false;
127        for qf in queue_families.iter() {
128            let flags = qf.queue_flags;
129            if flags.contains(vk::QueueFlags::TRANSFER)
130                && !(flags.contains(vk::QueueFlags::GRAPHICS)
131                    || flags.contains(vk::QueueFlags::COMPUTE))
132            {
133                dedicated_transfer_queue = true;
134            }
135            if flags.contains(vk::QueueFlags::COMPUTE) && !flags.contains(vk::QueueFlags::GRAPHICS)
136            {
137                dedicated_async_compute_queue = true;
138            }
139        }
140
141        // Check for ray tracing support via device extensions.
142        let extensions = unsafe {
143            instance
144                .enumerate_device_extension_properties(physical_device)
145                .unwrap_or_default()
146        };
147        let supports_ray_tracing = extensions.iter().any(|ext| {
148            let ext_name = unsafe { CStr::from_ptr(ext.extension_name.as_ptr()) };
149            ext_name.to_str().unwrap_or("") == "VK_KHR_ray_tracing_pipeline"
150                || ext_name.to_str().unwrap_or("") == "VK_NV_ray_tracing"
151        });
152
153        let mut characteristics = GPUCharacteristics {
154            memory_pressure,
155            // Vendor-specific fields start as None.
156            compute_units: None,
157            shader_engines: None,
158            shader_arrays_per_engine_count: None,
159            compute_units_per_shader_array: None,
160            simd_per_compute_unit: None,
161            wavefronts_per_simd: None,
162            wavefront_size: None,
163            streaming_multiprocessors: None,
164            warps_per_sm: None,
165            // General limits:
166            max_image_dimension_2d: limits.max_image_dimension2_d,
167            max_compute_shared_memory_size: limits.max_compute_shared_memory_size,
168            max_compute_work_group_invocations: limits.max_compute_work_group_invocations,
169            // New features:
170            dedicated_transfer_queue,
171            dedicated_async_compute_queue,
172            supports_ray_tracing,
173        };
174
175        // Query vendor-specific properties.
176        match vendor {
177            Vendor::AMD => {
178                let mut shader_core_properties = PhysicalDeviceShaderCorePropertiesAMD::default();
179                let mut shader_core_properties2 = PhysicalDeviceShaderCoreProperties2AMD::default();
180                let mut amd_properties2 = PhysicalDeviceProperties2::default()
181                    .push_next(&mut shader_core_properties)
182                    .push_next(&mut shader_core_properties2);
183                unsafe {
184                    instance.get_physical_device_properties2(physical_device, &mut amd_properties2);
185                }
186                characteristics.compute_units = Some(
187                    shader_core_properties.shader_engine_count
188                        * shader_core_properties.shader_arrays_per_engine_count
189                        * shader_core_properties.compute_units_per_shader_array,
190                );
191                characteristics.shader_engines = Some(shader_core_properties.shader_engine_count);
192                characteristics.shader_arrays_per_engine_count =
193                    Some(shader_core_properties.shader_arrays_per_engine_count);
194                characteristics.compute_units_per_shader_array =
195                    Some(shader_core_properties.compute_units_per_shader_array);
196                characteristics.simd_per_compute_unit =
197                    Some(shader_core_properties.simd_per_compute_unit);
198                characteristics.wavefronts_per_simd =
199                    Some(shader_core_properties.wavefronts_per_simd);
200                characteristics.wavefront_size = Some(shader_core_properties.wavefront_size);
201            }
202            Vendor::Nvidia => {
203                let mut sm_builtins = PhysicalDeviceShaderSMBuiltinsPropertiesNV::default();
204                let mut nv_properties2 =
205                    PhysicalDeviceProperties2::default().push_next(&mut sm_builtins);
206                unsafe {
207                    instance.get_physical_device_properties2(physical_device, &mut nv_properties2);
208                }
209                characteristics.streaming_multiprocessors = Some(sm_builtins.shader_sm_count);
210                characteristics.warps_per_sm = Some(sm_builtins.shader_warps_per_sm);
211            }
212            _ => {
213                // For other vendors, vendor-specific fields remain None.
214            }
215        };
216
217        Device {
218            vendor,
219            device_name,
220            device_type,
221            device_id,
222            vendor_id,
223            driver_name,
224            driver_info,
225            api_version,
226            heapbudget,
227            heapsize,
228            characteristics,
229        }
230    }
231}
232
/// Represents the type of device.
///
/// The discriminants `0..=4` are the integer IDs accepted by [`DeviceType::from`];
/// `Unknown` is the fallback for every other ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    Other = 0,
    IntegratedGPU = 1,
    DiscreteGPU = 2,
    VirtualGPU = 3,
    CPU = 4,
    Unknown = 5,
}

impl DeviceType {
    /// Converts an integer ID (from Vulkan) into a `DeviceType`.
    ///
    /// IDs outside `0..=4` (including negative values) map to [`DeviceType::Unknown`].
    pub fn from(id: i32) -> Self {
        match id {
            0 => DeviceType::Other,
            1 => DeviceType::IntegratedGPU,
            2 => DeviceType::DiscreteGPU,
            3 => DeviceType::VirtualGPU,
            4 => DeviceType::CPU,
            _ => DeviceType::Unknown,
        }
    }

    /// Returns a human‑readable name.
    pub fn name(&self) -> &'static str {
        match self {
            DeviceType::Other => "Other",
            DeviceType::IntegratedGPU => "Integrated GPU",
            DeviceType::DiscreteGPU => "Discrete GPU",
            DeviceType::VirtualGPU => "Virtual GPU",
            DeviceType::CPU => "CPU",
            DeviceType::Unknown => "Unknown",
        }
    }
}
269
/// Decodes a packed Vulkan version number into "variant.major.minor.patch".
///
/// Bit layout, from the most significant bit down: 3 bits variant, 7 bits major,
/// 10 bits minor, 12 bits patch.
pub fn decode_version_number(version: u32) -> String {
    // Extracts `width` bits of `version` starting at bit `shift`.
    let field = |shift: u32, width: u32| (version >> shift) & ((1u32 << width) - 1);

    let variant = field(29, 3);
    let major = field(22, 7);
    let minor = field(12, 10);
    let patch = field(0, 12);
    format!("{variant}.{major}.{minor}.{patch}")
}
278
/// Copies the contents of a `CStr` into an owned `String`.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD; the trailing NUL is not included.
pub fn cstring_to_string(cstr: &CStr) -> String {
    let raw_bytes = cstr.to_bytes();
    String::from_utf8_lossy(raw_bytes).into_owned()
}
283
#[cfg(test)]
mod tests {
    use super::*;
    use ash::vk;
    use std::ffi::CString;

    /// Builds an owned C string from `text` to feed the conversion helper.
    fn make_cstring(text: &str) -> CString {
        CString::new(text).unwrap()
    }

    /// A packed version with variant 0 and version 1.2.3 decodes to "0.1.2.3".
    #[test]
    fn test_decode_version_number() {
        let packed: u32 = (1 << 22) | (2 << 12) | 3;
        assert_eq!(decode_version_number(packed), "0.1.2.3");
    }

    /// Plain ASCII survives the CStr -> String conversion unchanged.
    #[test]
    fn test_cstring_to_string() {
        let original = "Hello, world!";
        let owned = make_cstring(original);
        assert_eq!(cstring_to_string(owned.as_c_str()), original);
    }

    /// Every known ID maps to its label; anything else maps to "Unknown".
    #[test]
    fn test_device_type_from() {
        let expectations = [
            (0, "Other"),
            (1, "Integrated GPU"),
            (2, "Discrete GPU"),
            (3, "Virtual GPU"),
            (4, "CPU"),
            (99, "Unknown"),
        ];
        for (id, label) in expectations {
            assert_eq!(DeviceType::from(id).name(), label);
        }
    }

    /// Limits copied into `GPUCharacteristics` are preserved and the
    /// vendor-specific fields can be left unset.
    #[test]
    fn test_gpu_characteristics_defaults() {
        // Dummy limits: only the three values under test are set explicitly.
        let limits = vk::PhysicalDeviceLimits {
            max_image_dimension2_d: 8192,
            max_compute_shared_memory_size: 16384,
            max_compute_work_group_invocations: 1024,
            ..Default::default()
        };

        // Characteristics carrying only the common limits.
        let characteristics = GPUCharacteristics {
            memory_pressure: 0.5,
            compute_units: None,
            shader_engines: None,
            shader_arrays_per_engine_count: None,
            compute_units_per_shader_array: None,
            simd_per_compute_unit: None,
            wavefronts_per_simd: None,
            wavefront_size: None,
            streaming_multiprocessors: None,
            warps_per_sm: None,
            max_image_dimension_2d: limits.max_image_dimension2_d,
            max_compute_shared_memory_size: limits.max_compute_shared_memory_size,
            max_compute_work_group_invocations: limits.max_compute_work_group_invocations,
            dedicated_transfer_queue: false,
            dedicated_async_compute_queue: false,
            supports_ray_tracing: false,
        };

        assert_eq!(characteristics.max_image_dimension_2d, 8192);
        assert_eq!(characteristics.max_compute_shared_memory_size, 16384);
        assert_eq!(characteristics.max_compute_work_group_invocations, 1024);
        assert!(characteristics.compute_units.is_none());
    }
}
356}