Skip to main content

vyre_driver_cuda/
device.rs

1//! CUDA device probing and capability snapshots.
2
3use std::{fmt, sync::Arc};
4
5use cudarc::driver::{result, sys::CUdevice_attribute, CudaContext};
6
7use crate::backend::staging_reserve::reserved_vec;
8
/// Build the operator-facing diagnostic for a failed CUDA context creation.
///
/// The message names the failing `ordinal`, embeds the driver `error`, and
/// appends a fixed remediation paragraph covering ordinal selection,
/// exclusive-process compute mode, and VRAM pressure.
fn format_cuda_context_init_error(ordinal: usize, error: impl fmt::Display) -> String {
    // Fixed remediation text; kept separate so the variable prefix stays short.
    const REMEDIATION: &str = "Fix: choose a visible `nvidia-smi -L` ordinal and ensure no exclusive-process compute mode blocks context creation. If the error is CUDA_ERROR_OUT_OF_MEMORY, treat it as live VRAM pressure during context creation: run `nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv`, free or move the processes holding VRAM, then rerun the CUDA-required validation; do not skip GPU tests or continue on a CPU path.";
    format!("CUDA context init failed for ordinal {ordinal}: {error}. {REMEDIATION}")
}
14
#[cfg(test)]
mod context_init_error_tests {
    use super::format_cuda_context_init_error;

    /// The out-of-memory diagnostic must echo the driver error, name VRAM
    /// pressure, show the inspection command, and forbid a CPU fallback.
    #[test]
    fn context_init_oom_diagnostic_names_vram_pressure_without_cpu_escape() {
        let message = format_cuda_context_init_error(0, "CUDA_ERROR_OUT_OF_MEMORY");
        let required_fragments = [
            "CUDA_ERROR_OUT_OF_MEMORY",
            "live VRAM pressure during context creation",
            "nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv",
            "do not skip GPU tests",
            "continue on a CPU path",
        ];
        for fragment in required_fragments.iter() {
            assert!(message.contains(fragment));
        }
    }
}
30
/// Queried physical limits and capabilities of a CUDA GPU.
///
/// Integer limits keep the signed `i32` values returned by the driver
/// attribute queries; the unsigned accessor methods on this type expose
/// checked `u32` views of them.
#[derive(Debug, Clone)]
pub struct CudaDeviceCaps {
    /// Device name string reported by the CUDA driver.
    pub name: String,
    /// CUDA device ordinal this snapshot was probed from.
    pub ordinal: usize,
    /// Hardware compute capability (major, minor).
    pub compute_capability: (u32, u32),
    /// Overall VRAM capacity in bytes.
    pub total_memory: u64,
    /// Maximum number of threads executable in one block.
    pub max_threads_per_block: i32,
    /// Maximum dimensions for a thread block (x, y, z).
    pub max_block_dim: [i32; 3],
    /// Maximum dimensions for a dispatch grid (x, y, z).
    pub max_grid_dim: [i32; 3],
    /// Shared memory available per thread block in bytes.
    pub shared_memory_per_block: i32,
    /// Shared memory available per streaming multiprocessor in bytes.
    pub shared_memory_per_sm: i32,
    /// Number of threads in a hardware warp.
    pub warp_size: i32,
    /// Whether the device supports cooperative grid launches (megakernel prerequisite).
    pub cooperative_launch: bool,
    /// Whether the device can run multiple kernels concurrently from different streams.
    pub concurrent_kernels: bool,
    /// Number of independent async copy engines available.
    pub async_engine_count: i32,
    /// Number of streaming multiprocessors. Used by runtime planners that
    /// need to size concurrent graph-replay lanes against real hardware
    /// width instead of a fixed host-side constant.
    pub multi_processor_count: i32,
    /// Device-wide L2 cache capacity in bytes.
    pub l2_cache_bytes: i32,
    /// Memory clock rate in kHz, as reported by the CUDA driver.
    pub memory_clock_rate_khz: i32,
    /// Global memory bus width in bits.
    pub global_memory_bus_width_bits: i32,
    /// Maximum 32-bit registers usable by a single thread block. Required
    /// for occupancy-aware workgroup sizing (I4): when ptxas reports a
    /// kernel's per-thread register pressure, this caps the largest block
    /// the driver can launch without spill.
    pub max_registers_per_block: i32,
    /// Maximum 32-bit registers available per streaming multiprocessor.
    /// Combined with kernel register pressure this gives the per-SM block
    /// concurrency limit for the I4 occupancy estimator.
    pub max_registers_per_sm: i32,
    /// Maximum threads resident on a streaming multiprocessor.
    /// `max_threads_per_sm / workgroup_size` is the upper bound on
    /// concurrent blocks per SM before register or shared-memory limits
    /// kick in.
    pub max_threads_per_sm: i32,
}
85
/// Centralized live CUDA device acquisition result.
///
/// Pairs a capability snapshot with the context it was probed through, so
/// callers receive both from a single acquisition.
#[derive(Debug, Clone)]
pub struct CudaDeviceHandle {
    /// Probed capabilities for the acquired device.
    pub caps: CudaDeviceCaps,
    /// Bound CUDA context for dispatch; shared via `Arc`, so cloning the
    /// handle clones the reference, not the context.
    pub ctx: Arc<CudaContext>,
}
94
95impl CudaDeviceHandle {
96    /// Acquire and bind a CUDA context for `ordinal`, returning the matching
97    /// capability snapshot from the same CUDA device handle.
98    ///
99    /// # Errors
100    ///
101    /// Returns an actionable error when the CUDA driver cannot initialize, the
102    /// ordinal is invalid, context creation fails, context binding fails, or a
103    /// required device attribute cannot be queried.
104    pub fn acquire_ordinal(ordinal: usize) -> Result<Self, String> {
105        let device_count = CudaDeviceCaps::visible_device_count()?;
106        if ordinal >= device_count {
107            return Err(format!(
108                "CUDA device ordinal {ordinal} is out of range for {device_count} visible device(s). Fix: select a CUDA device ordinal reported by `nvidia-smi`."
109            ));
110        }
111
112        let ctx = CudaContext::new(ordinal)
113            .map_err(|error| format_cuda_context_init_error(ordinal, error))?;
114        ctx.bind_to_thread().map_err(|e| {
115            format!(
116                "CUDA context bind failed for ordinal {ordinal}: {e}. Fix: repair CUDA context ownership before dispatch; GPU-required runs must not continue with an unbound context."
117            )
118        })?;
119        let caps = CudaDeviceCaps::probe_context(ordinal, &ctx)?;
120        Ok(Self { caps, ctx })
121    }
122}
123
124impl CudaDeviceCaps {
125    fn required_u32_capability(&self, name: &str, value: i32) -> u32 {
126        debug_assert!(
127            value > 0,
128            "CUDA device `{}` carried invalid {name}={value} after capability validation",
129            self.name
130        );
131        if value <= 0 {
132            tracing::error!(
133                "CUDA device `{}` carried invalid {name}={value} after capability validation. Fix: reject corrupt capability snapshots during probe.",
134                self.name
135            );
136            return 0;
137        }
138        u32::try_from(value).unwrap_or_else(|source| {
139            tracing::error!(
140                "CUDA device `{}` carried non-u32 {name}={value} after capability validation: {source}. Fix: reject corrupt capability snapshots during probe.",
141                self.name
142            );
143            0
144        })
145    }
146
147    /// Return the number of CUDA devices visible to the CUDA driver.
148    ///
149    /// # Errors
150    ///
151    /// Returns an error when the CUDA driver cannot initialize or report its
152    /// visible device count.
153    pub fn visible_device_count() -> Result<usize, String> {
154        result::init().map_err(|e| {
155            format!(
156                "CUDA driver init failed: {e}. Fix: verify `nvidia-smi` succeeds and libcuda.so from the NVIDIA driver is visible to this process."
157            )
158        })?;
159        let count = result::device::get_count()
160            .map_err(|e| {
161                format!(
162                    "CUDA device-count query failed: {e}. Fix: repair CUDA driver/device visibility; a GPU-required host must not report zero devices."
163                )
164            })?;
165        usize::try_from(count)
166            .map_err(|_| format!("CUDA device-count query returned negative value {count}"))
167    }
168
169    /// Probe every CUDA device visible to the process.
170    ///
171    /// # Errors
172    ///
173    /// Returns an actionable error when any visible device cannot be probed.
174    pub fn probe_all() -> Result<Vec<Self>, String> {
175        let device_count = Self::visible_device_count()?;
176        if device_count == 0 {
177            return Err(
178                "CUDA device-count query returned zero visible devices. Fix: this is a GPU-required release host; run `nvidia-smi -L`, repair CUDA_VISIBLE_DEVICES/container GPU passthrough, and do not silently continue on a CPU path."
179                    .to_string(),
180            );
181        }
182        let mut devices = reserved_vec(device_count, "cuda visible device probes")
183            .map_err(|error| error.to_string())?;
184        for ordinal in 0..device_count {
185            devices.push(Self::probe(ordinal)?);
186        }
187        Ok(devices)
188    }
189
190    /// Probe the device using the raw CUDA driver API.
191    ///
192    /// # Errors
193    ///
194    /// Returns an error when the CUDA driver cannot initialize, the ordinal is
195    /// out of range, or a required device attribute cannot be queried.
196    pub fn probe(ordinal: usize) -> Result<Self, String> {
197        let device_count = Self::visible_device_count()?;
198        if ordinal >= device_count {
199            return Err(format!(
200                "CUDA device ordinal {ordinal} is out of range for {device_count} visible device(s). Fix: select a CUDA device ordinal reported by `nvidia-smi`."
201            ));
202        }
203
204        let ctx = CudaContext::new(ordinal)
205            .map_err(|error| format_cuda_context_init_error(ordinal, error))?;
206        Self::probe_context(ordinal, &ctx)
207    }
208
209    fn probe_context(ordinal: usize, ctx: &CudaContext) -> Result<Self, String> {
210        let dev = ctx.cu_device();
211
212        let attr = |name: &str, attrib| {
213            // SAFETY: cuDeviceGetCount / cuDeviceGet operate on raw pointers we own on
214            // the current thread; the call returns CUresult and is wrapped in cuda_check.
215            unsafe { result::device::get_attribute(dev, attrib) }
216                .map_err(|e| format!("CUDA attribute query `{name}` failed: {e}"))
217        };
218
219        // SAFETY: cuDeviceGetCount / cuDeviceGet operate on raw pointers we own on
220        // the current thread; the call returns CUresult and is wrapped in cuda_check.
221        let total_memory = unsafe { result::device::total_mem(dev) }
222            .map_err(|e| format!("CUDA total-memory query failed: {e}"))?;
223        let name = result::device::get_name(dev)
224            .map_err(|e| format!("CUDA device-name query failed: {e}"))?;
225        let major = attr(
226            "compute_capability_major",
227            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
228        )?;
229        let minor = attr(
230            "compute_capability_minor",
231            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
232        )?;
233        let max_threads_per_block = attr(
234            "max_threads_per_block",
235            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
236        )?;
237        let max_block_dim_x = attr(
238            "max_block_dim_x",
239            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X,
240        )?;
241        let max_block_dim_y = attr(
242            "max_block_dim_y",
243            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y,
244        )?;
245        let max_block_dim_z = attr(
246            "max_block_dim_z",
247            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z,
248        )?;
249        let max_grid_dim_x = attr(
250            "max_grid_dim_x",
251            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
252        )?;
253        let max_grid_dim_y = attr(
254            "max_grid_dim_y",
255            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y,
256        )?;
257        let max_grid_dim_z = attr(
258            "max_grid_dim_z",
259            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z,
260        )?;
261        let shared_memory_per_block = attr(
262            "shared_memory_per_block",
263            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
264        )?;
265        let shared_memory_per_sm = attr(
266            "shared_memory_per_sm",
267            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
268        )?;
269        let warp_size = attr(
270            "warp_size",
271            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_WARP_SIZE,
272        )?;
273        let cooperative_launch = attr(
274            "cooperative_launch",
275            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH,
276        )?;
277        let concurrent_kernels = attr(
278            "concurrent_kernels",
279            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS,
280        )?;
281        let async_engine_count = attr(
282            "async_engine_count",
283            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT,
284        )?;
285        let multi_processor_count = attr(
286            "multi_processor_count",
287            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
288        )?;
289        let l2_cache_bytes = attr(
290            "l2_cache_bytes",
291            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
292        )?;
293        let memory_clock_rate_khz = attr(
294            "memory_clock_rate_khz",
295            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE,
296        )?;
297        let global_memory_bus_width_bits = attr(
298            "global_memory_bus_width_bits",
299            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH,
300        )?;
301        let max_registers_per_block = attr(
302            "max_registers_per_block",
303            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
304        )?;
305        let max_registers_per_sm = attr(
306            "max_registers_per_sm",
307            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR,
308        )?;
309        let max_threads_per_sm = attr(
310            "max_threads_per_sm",
311            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR,
312        )?;
313
314        let caps = Self {
315            name,
316            ordinal,
317            compute_capability: (
318                u32::try_from(major).map_err(|source| {
319                    format!("CUDA device major compute capability was negative ({major}): {source}. Fix: repair CUDA driver attribute probing before dispatch.")
320                })?,
321                u32::try_from(minor).map_err(|source| {
322                    format!("CUDA device minor compute capability was negative ({minor}): {source}. Fix: repair CUDA driver attribute probing before dispatch.")
323                })?,
324            ),
325            total_memory: u64::try_from(total_memory).map_err(|source| {
326                format!("CUDA total memory value {total_memory} does not fit u64: {source}. Fix: widen CudaDeviceCaps memory telemetry before dispatch.")
327            })?,
328            max_threads_per_block,
329            max_block_dim: [max_block_dim_x, max_block_dim_y, max_block_dim_z],
330            max_grid_dim: [max_grid_dim_x, max_grid_dim_y, max_grid_dim_z],
331            shared_memory_per_block,
332            shared_memory_per_sm,
333            warp_size,
334            cooperative_launch: cooperative_launch != 0,
335            concurrent_kernels: concurrent_kernels != 0,
336            async_engine_count,
337            multi_processor_count,
338            l2_cache_bytes,
339            memory_clock_rate_khz,
340            global_memory_bus_width_bits,
341            max_registers_per_block,
342            max_registers_per_sm,
343            max_threads_per_sm,
344        };
345        caps.validate_required_attributes()?;
346        Ok(caps)
347    }
348
349    fn validate_required_attributes(&self) -> Result<(), String> {
350        if self.name.trim().is_empty() {
351            return Err(format!(
352                "CUDA device ordinal {} returned an empty device name. Fix: repair CUDA driver probing before capability-dependent dispatch.",
353                self.ordinal
354            ));
355        }
356        if self.compute_capability.0 == 0 {
357            return Err(format!(
358                "CUDA device `{}` returned invalid compute capability {:?}. Fix: update the NVIDIA driver so CUDA attributes report a real SM target.",
359                self.name, self.compute_capability
360            ));
361        }
362        if self.total_memory == 0 {
363            return Err(format!(
364                "CUDA device `{}` reported zero total memory. Fix: repair CUDA device visibility; do not continue with bogus memory limits.",
365                self.name
366            ));
367        }
368        for (name, value) in [
369            ("max_threads_per_block", self.max_threads_per_block),
370            ("max_block_dim_x", self.max_block_dim[0]),
371            ("max_block_dim_y", self.max_block_dim[1]),
372            ("max_block_dim_z", self.max_block_dim[2]),
373            ("max_grid_dim_x", self.max_grid_dim[0]),
374            ("max_grid_dim_y", self.max_grid_dim[1]),
375            ("max_grid_dim_z", self.max_grid_dim[2]),
376            ("shared_memory_per_block", self.shared_memory_per_block),
377            ("shared_memory_per_sm", self.shared_memory_per_sm),
378            ("warp_size", self.warp_size),
379            ("multi_processor_count", self.multi_processor_count),
380            ("l2_cache_bytes", self.l2_cache_bytes),
381            ("memory_clock_rate_khz", self.memory_clock_rate_khz),
382            (
383                "global_memory_bus_width_bits",
384                self.global_memory_bus_width_bits,
385            ),
386            ("max_registers_per_block", self.max_registers_per_block),
387            ("max_registers_per_sm", self.max_registers_per_sm),
388            ("max_threads_per_sm", self.max_threads_per_sm),
389        ] {
390            if value <= 0 {
391                return Err(format!(
392                    "CUDA device `{}` reported invalid {name}={value}. Fix: repair CUDA capability probing before dispatch; zero/negative limits are a hard GPU configuration error.",
393                    self.name
394                ));
395            }
396        }
397        Ok(())
398    }
399
400    /// Native CUDA SM number reported by the device compute capability.
401    #[must_use]
402    pub fn native_sm(&self) -> u32 {
403        self.compute_capability.0 * 10 + self.compute_capability.1
404    }
405
406    /// PTX `.target sm_XX` selected for this device.
407    ///
408    /// The CUDA driver JIT accepts virtual PTX for the current architecture.
409    /// Capping this value below the live device hides architecture-specific
410    /// scheduling and invalidates cache keys across GPU generations.
411    #[must_use]
412    pub fn ptx_target_sm(&self) -> u32 {
413        self.native_sm()
414    }
415
416    /// Shared memory available per CUDA thread block in bytes.
417    #[must_use]
418    pub fn shared_memory_per_block_bytes(&self) -> u32 {
419        self.required_u32_capability("shared_memory_per_block", self.shared_memory_per_block)
420    }
421
422    /// Shared memory available per CUDA streaming multiprocessor in bytes.
423    #[must_use]
424    pub fn shared_memory_per_sm_bytes(&self) -> u32 {
425        self.required_u32_capability("shared_memory_per_sm", self.shared_memory_per_sm)
426    }
427
428    /// Maximum threads per block as an unsigned launch-limit value.
429    #[must_use]
430    pub fn max_threads_per_block_u32(&self) -> u32 {
431        self.required_u32_capability("max_threads_per_block", self.max_threads_per_block)
432    }
433
434    /// Maximum 32-bit registers per CUDA thread block.
435    #[must_use]
436    pub fn max_registers_per_block_u32(&self) -> u32 {
437        self.required_u32_capability("max_registers_per_block", self.max_registers_per_block)
438    }
439
440    /// Maximum 32-bit registers per streaming multiprocessor.
441    #[must_use]
442    pub fn max_registers_per_sm_u32(&self) -> u32 {
443        self.required_u32_capability("max_registers_per_sm", self.max_registers_per_sm)
444    }
445
446    /// Maximum resident threads per streaming multiprocessor.
447    #[must_use]
448    pub fn max_threads_per_sm_u32(&self) -> u32 {
449        self.required_u32_capability("max_threads_per_sm", self.max_threads_per_sm)
450    }
451
452    /// Number of streaming multiprocessors as an unsigned runtime-planning value.
453    #[must_use]
454    pub fn multi_processor_count_u32(&self) -> u32 {
455        self.required_u32_capability("multi_processor_count", self.multi_processor_count)
456    }
457
458    /// Device-wide L2 cache capacity in bytes.
459    #[must_use]
460    pub fn l2_cache_bytes_u32(&self) -> u32 {
461        self.required_u32_capability("l2_cache_bytes", self.l2_cache_bytes)
462    }
463
464    /// Approximate peak global-memory bandwidth in decimal GB/s.
465    #[must_use]
466    pub fn memory_bandwidth_gbps(&self) -> u32 {
467        let clock_khz =
468            self.required_u32_capability("memory_clock_rate_khz", self.memory_clock_rate_khz);
469        let bus_bits = self.required_u32_capability(
470            "global_memory_bus_width_bits",
471            self.global_memory_bus_width_bits,
472        );
473        let gbps = (u64::from(clock_khz) * u64::from(bus_bits)) / 4_000_000;
474        u32::try_from(gbps.max(1)).unwrap_or_else(|source| {
475            tracing::error!(
476                "CUDA device `{}` memory bandwidth {gbps} GB/s does not fit u32: {source}. Fix: normalize the bandwidth model before exporting device profile telemetry.",
477                self.name
478            );
479            u32::MAX
480        })
481    }
482
483    /// NVIDIA CUDA architectural register ceiling per thread.
484    #[must_use]
485    pub fn max_registers_per_thread_u32(&self) -> u32 {
486        self.max_registers_per_block_u32().min(255)
487    }
488
489    /// Per-axis block limits as unsigned launch-limit values.
490    #[must_use]
491    pub fn max_block_dim_u32(&self) -> [u32; 3] {
492        self.max_block_dim
493            .map(|value| self.required_u32_capability("max_block_dim axis", value))
494    }
495
496    /// Per-axis grid limits as unsigned launch-limit values.
497    #[must_use]
498    pub fn max_grid_dim_u32(&self) -> [u32; 3] {
499        self.max_grid_dim
500            .map(|value| self.required_u32_capability("max_grid_dim axis", value))
501    }
502
503    /// Warp width reported by the CUDA device.
504    #[must_use]
505    pub fn warp_size_u32(&self) -> Option<u32> {
506        Some(self.required_warp_size_u32())
507    }
508
509    /// Warp width reported by the CUDA device after probe-time validation.
510    #[must_use]
511    pub fn required_warp_size_u32(&self) -> u32 {
512        self.required_u32_capability("warp_size", self.warp_size)
513    }
514
515    /// Whether this device generation has native fp16 instructions.
516    #[must_use]
517    pub fn hardware_supports_f16(&self) -> bool {
518        self.compute_capability >= (5, 3)
519    }
520
521    /// Whether this device generation has native bf16 instructions.
522    #[must_use]
523    pub fn hardware_supports_bf16(&self) -> bool {
524        self.compute_capability >= (8, 0)
525    }
526
527    /// Whether this device generation exposes NVIDIA tensor-core instructions.
528    #[must_use]
529    pub fn hardware_supports_tensor_cores(&self) -> bool {
530        self.compute_capability >= (7, 0)
531    }
532
533    /// Project a CUDA device snapshot into the workspace-wide
534    /// [`vyre_foundation::optimizer::AdapterCaps`] (audit P0 #60). All vyre
535    /// backends consume the same typed capability shape so passes that
536    /// adapt to subgroup-ops, indirect dispatch, max workgroup size, or
537    /// shared-memory budget take a single typed input regardless of
538    /// backend identity.
539    ///
540    /// Mapping notes:
541    /// - `supports_subgroup_ops`: CUDA always supports warp shuffles
542    ///   (`__shfl_*`) on every supported architecture (compute capability
543    ///   ≥ 3.0), so this is `true`.
544    /// - `supports_indirect_dispatch`: CUDA exposes
545    ///   `cuLaunchKernelEx` and `cuLaunchCooperativeKernel` with
546    ///   indirect launch parameters; `true` when cooperative launch is
547    ///   reported (the megakernel prerequisite that exercises this).
548    /// - `supports_specialization_constants`: CUDA uses runtime kernel
549    ///   parameters rather than pipeline-creation specialization constants;
550    ///   surfaced as `false`.
551    /// - `subgroup_size`: warp size (32 on every shipping NVIDIA GPU,
552    ///   but probed live so future architectures stay correct).
553    #[must_use]
554    pub fn to_adapter_caps(&self) -> vyre_foundation::optimizer::AdapterCaps {
555        self.to_device_profile().into()
556    }
557
558    /// Project the probed device into the neutral driver profile.
559    #[must_use]
560    pub fn to_device_profile(&self) -> vyre_driver::DeviceProfile {
561        let subgroup = self.subgroup_caps();
562        let profile = vyre_driver::DeviceProfile {
563            backend: "cuda",
564            supports_subgroup_ops: subgroup.supports_subgroup,
565            supports_indirect_dispatch: self.cooperative_launch,
566            supports_distributed_collectives: false,
567            supports_specialization_constants: false,
568            supports_f16: self.hardware_supports_f16(),
569            supports_bf16: self.hardware_supports_bf16(),
570            supports_trap_propagation: true,
571            supports_tensor_cores: self.hardware_supports_tensor_cores(),
572            has_mul_high: true,
573            has_dual_issue_fp32_int32: true,
574            has_subgroup_shuffle: subgroup.supports_subgroup,
575            has_shared_memory: self.shared_memory_per_block_bytes() > 0,
576            max_native_int_width: 64,
577            max_workgroup_size: self.max_block_dim_u32(),
578            max_invocations_per_workgroup: self.max_threads_per_block_u32(),
579            max_shared_memory_bytes: self.shared_memory_per_block_bytes(),
580            max_storage_buffer_binding_size: self.total_memory,
581            subgroup_size: subgroup.subgroup_size,
582            compute_units: self.multi_processor_count_u32(),
583            regs_per_thread_max: self.max_registers_per_thread_u32(),
584            l1_cache_bytes: 0,
585            l2_cache_bytes: self.l2_cache_bytes_u32(),
586            mem_bw_gbps: self.memory_bandwidth_gbps(),
587            ideal_unroll_depth: 0,
588            ideal_vector_pack_bits: 0,
589            ideal_workgroup_tile: [0, 0, 0],
590            shared_memory_bank_count: 32,
591            shared_memory_bank_width_bytes: 4,
592        };
593        vyre_driver::DeviceSignatureTable::builtins().map_or(profile, |table| {
594            table.apply_generation_to_profile(self.native_sm(), profile)
595        })
596    }
597
598    /// Project CUDA warp capabilities into the shared subgroup record.
599    #[must_use]
600    pub fn subgroup_caps(&self) -> vyre_driver::SubgroupCaps {
601        vyre_driver::SubgroupCaps::native(self.required_warp_size_u32())
602    }
603}
604
#[cfg(test)]
mod tests {
    use crate::synthetic_device_caps::blackwell_sm120_caps_default;

    /// A device whose SM generation has a builtin signature gets the
    /// signature's tuning hints while keeping its probed compute-unit count.
    #[test]
    fn cuda_profile_applies_builtin_sm_signature() {
        let profile = blackwell_sm120_caps_default().to_device_profile();
        let table =
            vyre_driver::DeviceSignatureTable::builtins().expect("Fix: builtin signatures load");
        let signature = table
            .find_architecture_generation(120)
            .expect("Fix: SM_120 must match the builtin Blackwell signature");

        assert_eq!(profile.compute_units, 170);
        assert_eq!(profile.ideal_unroll_depth, signature.ideal_unroll_depth);
        assert_eq!(
            profile.ideal_vector_pack_bits,
            signature.ideal_vector_pack_bits
        );
        assert_eq!(profile.ideal_workgroup_tile, signature.ideal_workgroup_tile);
        assert_eq!(profile.shared_memory_bank_count, signature.bank_count);
    }

    /// A device with no matching builtin signature keeps the values derived
    /// from the probed snapshot.
    #[test]
    fn cuda_profile_preserves_probed_compute_units_without_builtin_signature() {
        let mut caps = blackwell_sm120_caps_default();
        // Pick a generation the builtin table is not expected to know.
        caps.compute_capability = (99, 0);
        caps.multi_processor_count = 13;

        let profile = caps.to_device_profile();

        assert_eq!(profile.compute_units, 13);
        assert_eq!(profile.regs_per_thread_max, 255);
        assert_eq!(profile.l2_cache_bytes, 96 * 1024 * 1024);
        assert_eq!(profile.mem_bw_gbps, 1792);
        assert_eq!(profile.max_invocations_per_workgroup, 1024);
        assert_eq!(profile.max_shared_memory_bytes, 128 * 1024);
        assert_eq!(caps.shared_memory_per_sm_bytes(), 256 * 1024);
        assert_eq!(profile.shared_memory_bank_count, 32);
        assert_eq!(profile.shared_memory_bank_width_bytes, 4);
    }

    /// Source-text contract: device discovery must treat zero visible devices
    /// as an error instead of returning an empty list.
    #[test]
    fn cuda_probe_all_rejects_zero_device_silent_fallback_by_contract() {
        let source = include_str!("device.rs");

        assert!(
            source.contains("device_count == 0")
                && source.contains("do not silently continue on a CPU path"),
            "Fix: CUDA device discovery must fail loudly when a GPU-required host reports zero visible devices."
        );
        // The forbidden pattern is assembled with `concat!` so this test's own
        // text does not contain it.
        assert!(
            !source.contains(concat!("(0..device_count)", ".map(Self::probe).collect()")),
            "Fix: CUDA device discovery must not hide zero devices behind an empty successful probe list."
        );
    }

    /// Source-text contract: capability conversion must use checked
    /// conversions and probe-time validation, with no abort path or invented
    /// defaults in the helpers.
    #[test]
    fn cuda_capability_conversion_has_no_production_panic_path() {
        let source = include_str!("device.rs");
        // Slice out the conversion helper: from its signature up to the doc
        // comment of the item that follows it.
        let start = source
            .find("fn required_u32_capability")
            .expect("Fix: CUDA capability conversion helper must exist");
        let end = source[start..]
            .find("/// Return the number of CUDA devices visible")
            .expect("Fix: CUDA capability conversion helper must stay before device discovery")
            + start;
        let helper = &source[start..end];

        assert!(
            !helper.contains("panic!("),
            "Fix: CUDA capability accessors must not abort production dispatch; probe-time validation must reject invalid capability snapshots with typed errors."
        );
        assert!(
            !helper.contains("unwrap_or(1)"),
            "Fix: CUDA capability accessors must not manufacture fake nonzero defaults after validation."
        );
        assert!(
            !helper.contains(" as u32"),
            "Fix: CUDA capability accessors must use checked integer conversion, not release-path narrowing casts."
        );
        // Slice out the bandwidth helper the same way.
        let bandwidth_start = source
            .find("pub fn memory_bandwidth_gbps")
            .expect("Fix: CUDA memory bandwidth helper must exist");
        let bandwidth_end = source[bandwidth_start..]
            .find("/// NVIDIA CUDA architectural register ceiling")
            .expect("Fix: CUDA memory bandwidth helper should precede register helper")
            + bandwidth_start;
        let bandwidth_helper = &source[bandwidth_start..bandwidth_end];
        assert!(
            !bandwidth_helper.contains(" as u32"),
            "Fix: CUDA bandwidth telemetry must not narrow with an unchecked cast."
        );
        assert!(
            bandwidth_helper.contains("u32::try_from"),
            "Fix: CUDA bandwidth telemetry must use checked conversion after widened arithmetic."
        );
        assert!(
            source.contains("caps.validate_required_attributes()?"),
            "Fix: CUDA capability probing must validate launch-critical values before exposing infallible accessors."
        );
        assert!(
            source.contains("CUDA device major compute capability was negative")
                && source.contains("map_err(|source|"),
            "Fix: CUDA capability probe must return typed errors for corrupt driver attributes instead of panicking."
        );
    }
}
714