vkfetch_rs/
lib.rs

1pub mod ascii_art;
2pub mod device;
3pub mod vendor;
4
5use ash::{self, Entry, Instance, vk};
6use device::Device;
7use std::{
8    error::Error,
9    io::{self, Write},
10};
11use vendor::Vendor;
12use vt::enable_virtual_terminal_processing;
13
/// ANSI SGR sequence that switches the terminal to bold text.
const BOLD: &str = "\x1B[1m";
/// ANSI SGR sequence that clears every active text attribute.
const RESET: &str = "\x1B[0m";
/// DEC private-mode sequence (`?7l`) that turns automatic line wrapping off.
const WRAP_OFF: &str = "\x1B[?7l";
/// DEC private-mode sequence (`?7h`) that turns automatic line wrapping back on.
const WRAP_ON: &str = "\x1B[?7h";
/// Indentation placed in front of each detail row below the title.
const ALIGNMENT: &str = "    ";
/// Placeholder printed when an art row has no matching info line.
const EMPTY: &str = "";
20
21/// Fetches and prints information for a given physical device.
22pub fn fetch_device(
23    instance: &Instance,
24    device_handle: vk::PhysicalDevice,
25) -> Result<(), Box<dyn Error>> {
26    let properties = unsafe { instance.get_physical_device_properties(device_handle) };
27    let mut properties2 = vk::PhysicalDeviceProperties2::default();
28    unsafe {
29        instance.get_physical_device_properties2(device_handle, &mut properties2);
30    }
31
32    let vendor = Vendor::from_vendor_id(properties.vendor_id)
33        .unwrap_or_else(|| panic!("unknown vendor: {}", properties.vendor_id));
34    let art = vendor.get_ascii_art();
35
36    let info = get_device_info(
37        &Device::new(instance, device_handle),
38        (if is_ansi_supported() {
39            vendor.get_alternative_style()
40        } else {
41            vendor.get_style()
42        })[0],
43    );
44
45    let _ = enable_virtual_terminal_processing();
46
47    for i in 0..art.len().max(info.len()) {
48        let art_line = art
49            .get(i)
50            .map(String::as_str)
51            .unwrap_or(r#"                                               "#);
52        let info_line = info.get(i).map(String::as_str).unwrap_or(EMPTY);
53
54        if is_ansi_supported() {
55            print!("{}", WRAP_OFF);
56            io::stdout().flush()?;
57        }
58
59        println!(" {} {}", art_line, info_line);
60    }
61    if is_ansi_supported() {
62        print!("{}", WRAP_ON);
63        io::stdout().flush()?;
64    }
65
66    println!();
67    Ok(())
68}
69
70/// Returns a vector of formatted strings representing the device info,
71/// including extra vendor-specific and general device limits.
72/// Lines for optional fields are only included if available.
73fn get_device_info(device: &Device, color: &str) -> Vec<String> {
74    let mut lines = Vec::new();
75
76    let title = format!(
77        "{}{}{}{}: {}",
78        BOLD,
79        color,
80        device.device_name,
81        RESET,
82        device.device_type.name()
83    );
84    let underline_len = device.device_name.len() + device.device_type.name().len() + 3;
85    let underline = "=".repeat(underline_len);
86
87    let meter_width = 30;
88    let filled = (device.characteristics.memory_pressure * meter_width as f32).round() as usize;
89
90    // Basic device info.
91    lines.push(title);
92    lines.push(format!("{}{}{}", BOLD, color, underline));
93    lines.push(format!(
94        "{}{}Device{}: 0x{:X} : 0x{:X} ({})",
95        ALIGNMENT,
96        color,
97        RESET,
98        device.device_id,
99        device.vendor_id,
100        device.vendor.name(),
101    ));
102    lines.push(format!(
103        "{}{}Driver{}: {} : {}",
104        ALIGNMENT, color, RESET, device.driver_name, device.driver_info
105    ));
106    lines.push(format!(
107        "{}{}API{}: {}",
108        ALIGNMENT, color, RESET, device.api_version
109    ));
110    lines.push(format!(
111        "{}{}VRAM{}: {}{}{} / {}",
112        ALIGNMENT,
113        color,
114        RESET,
115        color,
116        format_bytes(device.heapbudget),
117        RESET,
118        format_bytes(device.heapsize)
119    ));
120    lines.push(format!(
121        "{}[{}{}{}{}] % {}{:.2}{}",
122        ALIGNMENT,
123        color,
124        "|".repeat(filled),
125        RESET,
126        " ".repeat(meter_width - filled),
127        color,
128        device.characteristics.memory_pressure * 100.0,
129        RESET
130    ));
131
132    // Vendor-specific extra info.
133    if let Some(cu) = device.characteristics.compute_units {
134        lines.push(format!(
135            "{}{}Compute Units{}: {}",
136            ALIGNMENT, color, RESET, cu
137        ));
138    }
139    if let Some(se) = device.characteristics.shader_engines {
140        lines.push(format!(
141            "{}{}Shader Engines{}: {}",
142            ALIGNMENT, color, RESET, se
143        ));
144    }
145    if let Some(sapec) = device.characteristics.shader_arrays_per_engine_count {
146        lines.push(format!(
147            "{}{}Shader Arrays per Engine{}: {}",
148            ALIGNMENT, color, RESET, sapec
149        ));
150    }
151    if let Some(cups) = device.characteristics.compute_units_per_shader_array {
152        lines.push(format!(
153            "{}{}Compute Units per Shader Array{}: {}",
154            ALIGNMENT, color, RESET, cups
155        ));
156    }
157    if let Some(simd) = device.characteristics.simd_per_compute_unit {
158        lines.push(format!(
159            "{}{}SIMD per Compute Unit{}: {}",
160            ALIGNMENT, color, RESET, simd
161        ));
162    }
163    if let Some(wfs) = device.characteristics.wavefronts_per_simd {
164        lines.push(format!(
165            "{}{}Wavefronts per SIMD{}: {}",
166            ALIGNMENT, color, RESET, wfs
167        ));
168    }
169    if let Some(wfsz) = device.characteristics.wavefront_size {
170        lines.push(format!(
171            "{}{}Wavefront Size{}: {}",
172            ALIGNMENT, color, RESET, wfsz
173        ));
174    }
175    if let Some(sm) = device.characteristics.streaming_multiprocessors {
176        lines.push(format!(
177            "{}{}Streaming Multiprocessors{}: {}",
178            ALIGNMENT, color, RESET, sm
179        ));
180    }
181    if let Some(wps) = device.characteristics.warps_per_sm {
182        lines.push(format!(
183            "{}{}Warps per SM{}: {}",
184            ALIGNMENT, color, RESET, wps
185        ));
186    }
187
188    // General device limits.
189    // lines.push(format!(
190    //     "{}{}Max Image Dimension 2D{}: {}",
191    //     ALIGNMENT,
192    //     color,
193    //     RESET,
194    //     format_bytes(device.characteristics.max_image_dimension_2d.into())
195    // ));
196    lines.push(format!(
197        "{}{}Max Compute Shared Memory Size{}: {}",
198        ALIGNMENT,
199        color,
200        RESET,
201        format_bytes(device.characteristics.max_compute_shared_memory_size.into())
202    ));
203    lines.push(format!(
204        "{}{}Max Compute Work Group Invocations{}: {}",
205        ALIGNMENT,
206        color,
207        RESET,
208        format_bytes(
209            device
210                .characteristics
211                .max_compute_work_group_invocations
212                .into()
213        )
214    ));
215
216    let checkbox = |b: bool| if b { "[x]" } else { "[ ]" };
217    let x = checkbox(device.characteristics.supports_ray_tracing);
218    let y = checkbox(device.characteristics.dedicated_transfer_queue);
219    let z = checkbox(device.characteristics.dedicated_async_compute_queue);
220
221    lines.push(format!(
222        "{}{}Raytracing{}: {} | {}Dedicated Transfer Queue{}: {} | {}Dedicated Async Compute Queue{}: {}",
223        ALIGNMENT,
224        color, RESET, x,
225        color, RESET, y,
226        color, RESET, z,
227    ));
228
229    lines
230}
231
/// Renders a byte count as a human-readable string.
///
/// Values of at least 1024 bytes are scaled by powers of 1024 to the largest
/// fitting unit among KB, MB, GB and TB and shown with three decimals;
/// smaller values are shown as a whole number of bytes (`"<n> B"`).
fn format_bytes(bytes: u64) -> String {
    const STEP: f64 = 1024.0;
    // Largest unit first so the first match is the best fit.
    const UNITS: [(&str, f64); 4] = [
        ("TB", STEP * STEP * STEP * STEP),
        ("GB", STEP * STEP * STEP),
        ("MB", STEP * STEP),
        ("KB", STEP),
    ];

    let amount = bytes as f64;
    UNITS
        .iter()
        .find(|(_, scale)| amount >= *scale)
        .map(|(unit, scale)| format!("{:.3} {}", amount / scale, unit))
        .unwrap_or_else(|| format!("{} B", bytes))
}
251
252/// Iterates through API versions and prints info for every physical device
253pub fn iterate_devices() -> Result<(), Box<dyn Error>> {
254    let entry = {
255        #[cfg(not(feature = "loaded"))]
256        {
257            Entry::linked()
258        }
259        #[cfg(feature = "loaded")]
260        {
261            match unsafe { Entry::load() } {
262                Ok(entry) => entry,
263                Err(e) => {
264                    eprintln!("Failed to load entry: {:?}", e);
265                    return Ok(());
266                }
267            }
268        }
269    };
270
271    for api_version in [
272        vk::API_VERSION_1_3,
273        vk::API_VERSION_1_2,
274        vk::API_VERSION_1_1,
275        vk::API_VERSION_1_0,
276    ] {
277        let app_info = vk::ApplicationInfo {
278            api_version,
279            ..Default::default()
280        };
281        let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);
282
283        match unsafe { entry.create_instance(&create_info, None) } {
284            Ok(instance) => match unsafe { instance.enumerate_physical_devices() } {
285                Ok(devices) => {
286                    for device in devices {
287                        fetch_device(&instance, device)?;
288                    }
289                }
290                Err(e) => {
291                    eprintln!("Failed to enumerate physical devices: {:?}", e);
292                }
293            },
294            Err(e) => {
295                eprintln!("Failed to create instance: {:?}", e);
296                continue;
297            }
298        };
299
300        break;
301    }
302    Ok(())
303}
304
#[cfg(windows)]
mod vt {
    use std::io::{Error, Result};
    use winapi::um::consoleapi::{GetConsoleMode, SetConsoleMode};
    use winapi::um::handleapi::INVALID_HANDLE_VALUE;
    use winapi::um::processenv::GetStdHandle;
    use winapi::um::winbase::STD_OUTPUT_HANDLE;
    use winapi::um::wincon::ENABLE_VIRTUAL_TERMINAL_PROCESSING;

    /// Turns on Virtual Terminal Processing for the stdout console.
    ///
    /// The current console mode is read, the
    /// `ENABLE_VIRTUAL_TERMINAL_PROCESSING` bit is OR-ed in and the result is
    /// written back. If obtaining the handle, reading the mode or writing
    /// the mode fails, the last OS error is returned.
    pub fn enable_virtual_terminal_processing() -> Result<()> {
        // SAFETY: the only pointer handed to the Win32 calls is
        // `&mut current`, which refers to a live local.
        unsafe {
            let console = GetStdHandle(STD_OUTPUT_HANDLE);
            let mut current = 0;
            // Short-circuiting keeps the calls in order and stops at the
            // first failure, so `last_os_error` belongs to the failed call.
            if console == INVALID_HANDLE_VALUE
                || GetConsoleMode(console, &mut current) == 0
                || SetConsoleMode(console, current | ENABLE_VIRTUAL_TERMINAL_PROCESSING) == 0
            {
                return Err(Error::last_os_error());
            }
        }
        Ok(())
    }

    /// Reports whether Virtual Terminal Processing is enabled on stdout.
    ///
    /// Returns `false` when the stdout handle is invalid or its console mode
    /// cannot be read.
    pub fn is_vt_enabled() -> bool {
        // SAFETY: the only pointer handed to the Win32 calls is
        // `&mut current`, which refers to a live local.
        unsafe {
            let console = GetStdHandle(STD_OUTPUT_HANDLE);
            let mut current = 0;
            console != INVALID_HANDLE_VALUE
                && GetConsoleMode(console, &mut current) != 0
                && (current & ENABLE_VIRTUAL_TERMINAL_PROCESSING) != 0
        }
    }
}
348
#[cfg(not(windows))]
mod vt {
    /// No-op counterpart of the Windows implementation; always succeeds.
    pub fn enable_virtual_terminal_processing() -> std::io::Result<()> {
        Ok(())
    }

    /// Always `true`: off Windows, ANSI escape handling is assumed to be
    /// available.
    pub fn is_vt_enabled() -> bool {
        true
    }
}
363
364/// Returns `true` if stdout is a TTY and (on Windows) VT processing is enabled.
365fn is_ansi_supported() -> bool {
366    std::io::IsTerminal::is_terminal(&std::io::stdout()) && vt::is_vt_enabled()
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, GPUCharacteristics};
    use crate::vendor::Vendor;

    /// For testing purposes we use the Unknown vendor variant.
    impl Vendor {
        pub fn dummy() -> Self {
            Vendor::Unknown
        }
    }

    /// Creates a dummy `Device` with every optional characteristic populated.
    fn dummy_physical_device() -> Device {
        Device {
            vendor: Vendor::dummy(),
            device_name: "TestDevice".to_string(),
            device_type: crate::device::DeviceType::DiscreteGPU,
            device_id: 0xDEADBEEF,
            vendor_id: 0xBEEF,
            driver_name: "TestDriver".to_string(),
            driver_info: "TestDriverInfo".to_string(),
            api_version: "1.2.3.4".to_string(),
            heapbudget: 8 * 1024 * 1024 * 1024, // 8 GB
            heapsize: 10 * 1024 * 1024 * 1024,  // 10 GB
            characteristics: GPUCharacteristics {
                memory_pressure: 0.2, // 20%
                compute_units: Some(10),
                shader_engines: Some(2),
                shader_arrays_per_engine_count: Some(2),
                compute_units_per_shader_array: Some(5),
                simd_per_compute_unit: Some(64),
                wavefronts_per_simd: Some(4),
                wavefront_size: Some(32),
                streaming_multiprocessors: Some(46),
                warps_per_sm: Some(32),
                max_image_dimension_2d: 16384,
                max_compute_shared_memory_size: 65536,
                max_compute_work_group_invocations: 1024,
                dedicated_transfer_queue: true,
                dedicated_async_compute_queue: true,
                supports_ray_tracing: true,
            },
        }
    }

    /// Returns the row whose label is exactly `label`.
    ///
    /// Rows are matched on `label` + `RESET` + `": "`, so a value is never
    /// mistaken for a label and the colour escape cannot satisfy a match.
    fn labelled_line<'a>(info: &'a [String], label: &str) -> &'a str {
        let marker = format!("{}{}: ", label, RESET);
        info.iter()
            .map(String::as_str)
            .find(|line| line.contains(marker.as_str()))
            .unwrap_or_else(|| panic!("no line labelled {:?}", label))
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(500), "500 B");
        assert_eq!(format_bytes(1023), "1023 B");
        assert_eq!(format_bytes(1024), "1.000 KB");
        assert_eq!(format_bytes(1536), "1.500 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.000 MB");
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.000 GB");
        assert_eq!(format_bytes(1024 * 1024 * 1024 * 1024), "1.000 TB");
    }

    #[test]
    fn test_get_device_info() {
        let device = dummy_physical_device();
        // The colour code deliberately contains "32": value checks below
        // must not be satisfiable by the escape sequence alone.
        let color = "\x1B[32m";
        let info = get_device_info(&device, color);

        assert!(info.len() >= 9);
        assert!(info[0].contains("TestDevice"));
        assert!(info[0].contains(device.device_type.name()));
        assert!(info[2].contains("0xDEADBEEF"));
        assert!(info[2].contains("0xBEEF"));

        let vram = labelled_line(&info, "VRAM");
        assert!(vram.contains("8.000 GB"));
        assert!(vram.contains("10.000 GB"));

        assert!(labelled_line(&info, "Compute Units").ends_with(": 10"));
        assert!(labelled_line(&info, "Shader Engines").ends_with(": 2"));
        assert!(labelled_line(&info, "Wavefront Size").ends_with(": 32"));
        assert!(labelled_line(&info, "Streaming Multiprocessors").ends_with(": 46"));
        assert!(labelled_line(&info, "Warps per SM").ends_with(": 32"));
    }

    #[test]
    fn test_memory_meter() {
        let info = get_device_info(&dummy_physical_device(), "\x1B[32m");
        let meter = info
            .iter()
            .find(|line| line.contains("] % "))
            .expect("memory meter line is present");

        // 20% of a 30-cell meter is 6 filled cells followed by 24 blanks.
        let bar = format!("{}{}{}]", "|".repeat(6), RESET, " ".repeat(24));
        assert!(meter.contains(bar.as_str()));
        assert!(meter.contains("20.00"));
    }
}