Skip to main content

vyre_driver/
device_profile.rs

1//! Backend-neutral device capability profile.
2//!
3//! Concrete backend crates probe their native device/API surfaces and project
4//! them into this value object. Shared optimizer, validation, launch, and
5//! strategy code consume projections of this profile instead of carrying
6//! independent capability records that can drift.
7
8use vyre_foundation::optimizer::AdapterCaps;
9use vyre_foundation::validate;
10
/// Level of trust that can be placed in timing data reported through [`DeviceProfile`].
///
/// Variants are listed from the weakest evidence (host clocks only) to the richest
/// (device elapsed time plus hardware counter samples).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceTimingQuality {
    /// Only host wall-clock timing is available.
    HostOnly,
    /// Host enqueue and host wait time can be measured separately; device elapsed time is not trusted.
    HostEnqueueWait,
    /// Device elapsed time is available through timestamp queries or events.
    DeviceTimestamps,
    /// Device elapsed time is available together with hardware counter samples.
    HardwareCounters,
}

impl DeviceTimingQuality {
    /// Returns the stable snake_case label used for timing-quality evidence in reports and config.
    ///
    /// These labels are part of the report/config format and must not change.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            DeviceTimingQuality::HostOnly => "host_only",
            DeviceTimingQuality::HostEnqueueWait => "host_enqueue_wait",
            DeviceTimingQuality::DeviceTimestamps => "device_timestamps",
            DeviceTimingQuality::HardwareCounters => "hardware_counters",
        }
    }
}
36
37/// Device capability snapshot used across driver-shared planning.
38#[derive(Clone, Copy, Debug, Eq, PartialEq)]
39pub struct DeviceProfile {
40    /// Stable backend identifier.
41    pub backend: &'static str,
42    /// The device and lowering path support subgroup intrinsics.
43    pub supports_subgroup_ops: bool,
44    /// The backend supports indirect dispatch.
45    pub supports_indirect_dispatch: bool,
46    /// The backend lowers distributed collective communication nodes.
47    pub supports_distributed_collectives: bool,
48    /// The backend supports compile-time specialization constants.
49    pub supports_specialization_constants: bool,
50    /// The backend lowers binary16 natively.
51    pub supports_f16: bool,
52    /// The backend lowers bfloat16 natively.
53    pub supports_bf16: bool,
54    /// The backend preserves explicit trap propagation.
55    pub supports_trap_propagation: bool,
56    /// The backend lowers matrix-engine operations for supported shapes.
57    pub supports_tensor_cores: bool,
58    /// Native unsigned multiply-high is available to lowering strategies.
59    pub has_mul_high: bool,
60    /// Integer and float pipelines can issue concurrently.
61    pub has_dual_issue_fp32_int32: bool,
62    /// Subgroup shuffle-like communication is available.
63    pub has_subgroup_shuffle: bool,
64    /// Explicit workgroup/shared memory is available.
65    pub has_shared_memory: bool,
66    /// Maximum native integer width in bits.
67    pub max_native_int_width: u32,
68    /// Maximum workgroup dimensions.
69    pub max_workgroup_size: [u32; 3],
70    /// Maximum invocations in one workgroup.
71    pub max_invocations_per_workgroup: u32,
72    /// Shared memory per workgroup in bytes.
73    pub max_shared_memory_bytes: u32,
74    /// Maximum single storage-buffer binding in bytes.
75    pub max_storage_buffer_binding_size: u64,
76    /// Native subgroup size, or `0` when unknown.
77    pub subgroup_size: u32,
78    /// Physical compute-unit count, or `0` when unknown.
79    pub compute_units: u32,
80    /// Maximum registers per thread, or `0` when unknown.
81    pub regs_per_thread_max: u32,
82    /// L1 cache size in bytes, or `0` when unknown.
83    pub l1_cache_bytes: u32,
84    /// L2 cache size in bytes, or `0` when unknown.
85    pub l2_cache_bytes: u32,
86    /// Peak memory bandwidth in GB/s, or `0` when unknown.
87    pub mem_bw_gbps: u32,
88    /// Timing-data quality exposed by this backend/device.
89    pub timing_quality: DeviceTimingQuality,
90    /// Device timestamp queries/events are available for dispatch timing.
91    pub supports_device_timestamps: bool,
92    /// Hardware counter sampling is available for benchmark telemetry.
93    pub supports_hardware_counters: bool,
94    /// Device-profile preferred unroll depth, or `0` when unknown.
95    pub ideal_unroll_depth: u32,
96    /// Device-profile preferred vector pack width in bits, or `0` when unknown.
97    pub ideal_vector_pack_bits: u32,
98    /// Device-profile preferred workgroup tile, or `[0, 0, 0]` when unknown.
99    pub ideal_workgroup_tile: [u32; 3],
100    /// Shared-memory bank count, or `0` when unknown.
101    pub shared_memory_bank_count: u32,
102    /// Shared-memory bank width in bytes, or `0` when unknown.
103    pub shared_memory_bank_width_bytes: u32,
104}
105
106impl Default for DeviceProfile {
107    fn default() -> Self {
108        Self::conservative("unknown")
109    }
110}
111
112impl DeviceProfile {
113    /// Conservative profile for a backend that has not probed a device.
114    #[must_use]
115    pub const fn conservative(backend: &'static str) -> Self {
116        Self {
117            backend,
118            supports_subgroup_ops: false,
119            supports_indirect_dispatch: false,
120            supports_distributed_collectives: false,
121            supports_specialization_constants: false,
122            supports_f16: false,
123            supports_bf16: false,
124            supports_trap_propagation: false,
125            supports_tensor_cores: false,
126            has_mul_high: false,
127            has_dual_issue_fp32_int32: false,
128            has_subgroup_shuffle: false,
129            has_shared_memory: false,
130            max_native_int_width: 32,
131            max_workgroup_size: [1, 1, 1],
132            max_invocations_per_workgroup: 1,
133            max_shared_memory_bytes: 0,
134            max_storage_buffer_binding_size: 0,
135            subgroup_size: 0,
136            compute_units: 0,
137            regs_per_thread_max: 0,
138            l1_cache_bytes: 0,
139            l2_cache_bytes: 0,
140            mem_bw_gbps: 0,
141            timing_quality: DeviceTimingQuality::HostOnly,
142            supports_device_timestamps: false,
143            supports_hardware_counters: false,
144            ideal_unroll_depth: 0,
145            ideal_vector_pack_bits: 0,
146            ideal_workgroup_tile: [0, 0, 0],
147            shared_memory_bank_count: 0,
148            shared_memory_bank_width_bytes: 0,
149        }
150    }
151
152    /// Build a profile from the stable backend trait capability methods.
153    #[must_use]
154    pub fn from_backend(backend: &dyn crate::backend::VyreBackend) -> Self {
155        let max_workgroup_size = backend.max_workgroup_size();
156        Self {
157            backend: backend.id(),
158            supports_subgroup_ops: backend.supports_subgroup_ops(),
159            supports_indirect_dispatch: backend.supports_indirect_dispatch(),
160            supports_distributed_collectives: backend.supports_distributed_collectives(),
161            supports_specialization_constants: false,
162            supports_f16: backend.supports_f16(),
163            supports_bf16: backend.supports_bf16(),
164            supports_trap_propagation: false,
165            supports_tensor_cores: backend.supports_tensor_cores(),
166            has_mul_high: false,
167            has_dual_issue_fp32_int32: false,
168            has_subgroup_shuffle: backend.supports_subgroup_ops(),
169            has_shared_memory: false,
170            max_native_int_width: 32,
171            max_workgroup_size,
172            max_invocations_per_workgroup: backend.max_compute_invocations_per_workgroup(),
173            max_shared_memory_bytes: 0,
174            max_storage_buffer_binding_size: backend.max_storage_buffer_bytes(),
175            subgroup_size: backend.subgroup_size().unwrap_or(0),
176            compute_units: 0,
177            regs_per_thread_max: 0,
178            l1_cache_bytes: 0,
179            l2_cache_bytes: 0,
180            mem_bw_gbps: 0,
181            timing_quality: DeviceTimingQuality::HostOnly,
182            supports_device_timestamps: false,
183            supports_hardware_counters: false,
184            ideal_unroll_depth: 0,
185            ideal_vector_pack_bits: 0,
186            ideal_workgroup_tile: [0, 0, 0],
187            shared_memory_bank_count: 0,
188            shared_memory_bank_width_bytes: 0,
189        }
190    }
191
192    /// Validation capability projection.
193    #[must_use]
194    pub const fn validation_capabilities(self) -> validate::BackendCapabilities {
195        validate::BackendCapabilities {
196            supports_subgroup_ops: self.supports_subgroup_ops,
197            supports_indirect_dispatch: self.supports_indirect_dispatch,
198            supports_specialization_constants: self.supports_specialization_constants,
199            has_mul_high: self.has_mul_high,
200            has_dual_issue_fp32_int32: self.has_dual_issue_fp32_int32,
201            has_tensor_core_int: self.supports_tensor_cores,
202            has_native_f16: self.supports_f16,
203            has_warp_shuffle: self.has_subgroup_shuffle,
204            has_shared_memory: self.has_shared_memory,
205            has_transcendental_polynomial_emit: true,
206            supports_distributed_collectives: self.supports_distributed_collectives,
207            max_native_int_width: self.max_native_int_width,
208        }
209    }
210
211    /// Optimizer capability projection.
212    #[must_use]
213    pub const fn adapter_caps(self) -> AdapterCaps {
214        AdapterCaps {
215            backend: self.backend,
216            supports_subgroup_ops: self.supports_subgroup_ops,
217            supports_indirect_dispatch: self.supports_indirect_dispatch,
218            supports_specialization_constants: self.supports_specialization_constants,
219            max_workgroup_size: self.max_workgroup_size,
220            max_invocations_per_workgroup: self.max_invocations_per_workgroup,
221            max_shared_memory_bytes: self.max_shared_memory_bytes,
222            max_storage_buffer_binding_size: self.max_storage_buffer_binding_size,
223            subgroup_size: self.subgroup_size,
224            compute_units: self.compute_units,
225            regs_per_thread_max: self.regs_per_thread_max,
226            l1_cache_bytes: self.l1_cache_bytes,
227            l2_cache_bytes: self.l2_cache_bytes,
228            mem_bw_gbps: self.mem_bw_gbps,
229            ideal_unroll_depth: self.ideal_unroll_depth,
230            ideal_vector_pack_bits: self.ideal_vector_pack_bits,
231            ideal_workgroup_tile: self.ideal_workgroup_tile,
232            shared_memory_bank_count: self.shared_memory_bank_count,
233            shared_memory_bank_width_bytes: self.shared_memory_bank_width_bytes,
234        }
235    }
236
237    /// Strategy capability projection.
238    #[must_use]
239    pub const fn strategy_capabilities(self) -> validate::BackendCapabilities {
240        self.validation_capabilities()
241    }
242}
243
244impl From<DeviceProfile> for AdapterCaps {
245    #[inline]
246    fn from(profile: DeviceProfile) -> Self {
247        profile.adapter_caps()
248    }
249}
250
251impl From<DeviceProfile> for validate::BackendCapabilities {
252    #[inline]
253    fn from(profile: DeviceProfile) -> Self {
254        profile.validation_capabilities()
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::{DeviceProfile, DeviceTimingQuality};
    use vyre_foundation::optimizer::AdapterCaps;
    use vyre_foundation::validate::BackendCapabilities;

    /// Profile with representative values in every field.
    fn populated_profile() -> DeviceProfile {
        DeviceProfile {
            backend: "test",
            supports_subgroup_ops: true,
            supports_indirect_dispatch: true,
            supports_distributed_collectives: true,
            supports_specialization_constants: true,
            supports_f16: true,
            supports_bf16: false,
            supports_trap_propagation: true,
            supports_tensor_cores: true,
            has_mul_high: true,
            has_dual_issue_fp32_int32: true,
            has_subgroup_shuffle: true,
            has_shared_memory: true,
            max_native_int_width: 64,
            max_workgroup_size: [256, 1, 1],
            max_invocations_per_workgroup: 256,
            max_shared_memory_bytes: 48 * 1024,
            max_storage_buffer_binding_size: 1 << 30,
            subgroup_size: 32,
            compute_units: 128,
            regs_per_thread_max: 255,
            l1_cache_bytes: 128 * 1024,
            l2_cache_bytes: 64 * 1024 * 1024,
            mem_bw_gbps: 1700,
            timing_quality: DeviceTimingQuality::HardwareCounters,
            supports_device_timestamps: true,
            supports_hardware_counters: true,
            ideal_unroll_depth: 8,
            ideal_vector_pack_bits: 128,
            ideal_workgroup_tile: [16, 16, 1],
            shared_memory_bank_count: 32,
            shared_memory_bank_width_bytes: 4,
        }
    }

    /// Asserts that every profile-derived field of a validation projection mirrors its source.
    fn assert_validation_mirrors(profile: &DeviceProfile, caps: &BackendCapabilities) {
        assert_eq!(caps.supports_subgroup_ops, profile.supports_subgroup_ops);
        assert_eq!(caps.supports_indirect_dispatch, profile.supports_indirect_dispatch);
        assert_eq!(
            caps.supports_specialization_constants,
            profile.supports_specialization_constants
        );
        assert_eq!(caps.has_mul_high, profile.has_mul_high);
        assert_eq!(
            caps.has_dual_issue_fp32_int32,
            profile.has_dual_issue_fp32_int32
        );
        assert_eq!(caps.has_tensor_core_int, profile.supports_tensor_cores);
        assert_eq!(caps.has_native_f16, profile.supports_f16);
        assert_eq!(caps.has_warp_shuffle, profile.has_subgroup_shuffle);
        assert_eq!(caps.has_shared_memory, profile.has_shared_memory);
        assert_eq!(
            caps.supports_distributed_collectives,
            profile.supports_distributed_collectives
        );
        assert_eq!(caps.max_native_int_width, profile.max_native_int_width);
    }

    /// Asserts that every field of an optimizer projection mirrors its source.
    fn assert_adapter_mirrors(profile: &DeviceProfile, caps: &AdapterCaps) {
        assert_eq!(caps.backend, profile.backend);
        assert_eq!(caps.supports_subgroup_ops, profile.supports_subgroup_ops);
        assert_eq!(caps.supports_indirect_dispatch, profile.supports_indirect_dispatch);
        assert_eq!(
            caps.supports_specialization_constants,
            profile.supports_specialization_constants
        );
        assert_eq!(caps.max_workgroup_size, profile.max_workgroup_size);
        assert_eq!(
            caps.max_invocations_per_workgroup,
            profile.max_invocations_per_workgroup
        );
        assert_eq!(caps.max_shared_memory_bytes, profile.max_shared_memory_bytes);
        assert_eq!(
            caps.max_storage_buffer_binding_size,
            profile.max_storage_buffer_binding_size
        );
        assert_eq!(caps.subgroup_size, profile.subgroup_size);
        assert_eq!(caps.compute_units, profile.compute_units);
        assert_eq!(caps.regs_per_thread_max, profile.regs_per_thread_max);
        assert_eq!(caps.l1_cache_bytes, profile.l1_cache_bytes);
        assert_eq!(caps.l2_cache_bytes, profile.l2_cache_bytes);
        assert_eq!(caps.mem_bw_gbps, profile.mem_bw_gbps);
        assert_eq!(caps.ideal_unroll_depth, profile.ideal_unroll_depth);
        assert_eq!(caps.ideal_vector_pack_bits, profile.ideal_vector_pack_bits);
        assert_eq!(caps.ideal_workgroup_tile, profile.ideal_workgroup_tile);
        assert_eq!(caps.shared_memory_bank_count, profile.shared_memory_bank_count);
        assert_eq!(
            caps.shared_memory_bank_width_bytes,
            profile.shared_memory_bank_width_bytes
        );
    }

    #[test]
    fn timing_quality_has_stable_report_strings() {
        assert_eq!(DeviceTimingQuality::HostOnly.as_str(), "host_only");
        assert_eq!(
            DeviceTimingQuality::HostEnqueueWait.as_str(),
            "host_enqueue_wait"
        );
        assert_eq!(
            DeviceTimingQuality::DeviceTimestamps.as_str(),
            "device_timestamps"
        );
        assert_eq!(
            DeviceTimingQuality::HardwareCounters.as_str(),
            "hardware_counters"
        );
    }

    #[test]
    fn default_is_conservative_unknown_profile() {
        assert_eq!(
            DeviceProfile::default(),
            DeviceProfile::conservative("unknown")
        );
    }

    #[test]
    fn projections_share_the_same_feature_bits() {
        let profile = populated_profile();

        let validation = profile.validation_capabilities();
        let adapter = profile.adapter_caps();
        let strategy = profile.strategy_capabilities();

        assert!(validation.supports_subgroup_ops);
        assert!(validation.supports_distributed_collectives);
        assert!(adapter.supports_subgroup_ops);
        assert!(strategy.has_warp_shuffle);
        assert_eq!(adapter.max_invocations_per_workgroup, 256);
        assert_eq!(adapter.ideal_unroll_depth, 8);
        assert_eq!(adapter.ideal_vector_pack_bits, 128);
        assert_eq!(adapter.ideal_workgroup_tile, [16, 16, 1]);
        assert_eq!(strategy.max_native_int_width, 64);
        assert_eq!(profile.timing_quality, DeviceTimingQuality::HardwareCounters);
        assert!(profile.supports_device_timestamps);
        assert!(profile.supports_hardware_counters);

        // Every projected field, not just a sample, must mirror the profile.
        assert_validation_mirrors(&profile, &validation);
        assert_validation_mirrors(&profile, &strategy);
        assert_adapter_mirrors(&profile, &adapter);
    }

    #[test]
    fn conservative_projections_mirror_profile() {
        let profile = DeviceProfile::conservative("probe-less");
        assert_validation_mirrors(&profile, &profile.validation_capabilities());
        assert_validation_mirrors(&profile, &profile.strategy_capabilities());
        assert_adapter_mirrors(&profile, &profile.adapter_caps());
    }

    #[test]
    fn from_impls_match_projection_methods() {
        let profile = populated_profile();
        let adapter: AdapterCaps = profile.into();
        let validation: BackendCapabilities = profile.into();
        assert_adapter_mirrors(&profile, &adapter);
        assert_validation_mirrors(&profile, &validation);
    }
}