1use vyre_foundation::optimizer::AdapterCaps;
9use vyre_foundation::validate;
10
/// Label for the quality of timing data associated with a device profile.
///
/// Every variant has a fixed snake_case rendering, available through
/// [`DeviceTimingQuality::as_str`]. The per-variant descriptions below are
/// derived from the variant names; confirm the precise meaning against the
/// backends that report them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceTimingQuality {
    /// Rendered as `"host_only"`. Presumably timing observed purely on the host.
    HostOnly,
    /// Rendered as `"host_enqueue_wait"`. Presumably host timing around
    /// enqueue-and-wait.
    HostEnqueueWait,
    /// Rendered as `"device_timestamps"`. Presumably device-side timestamps.
    DeviceTimestamps,
    /// Rendered as `"hardware_counters"`. Presumably hardware performance counters.
    HardwareCounters,
}

impl DeviceTimingQuality {
    /// Returns the fixed snake_case label for this variant.
    ///
    /// Usable in `const` contexts. The match is intentionally exhaustive (no
    /// wildcard arm) so adding a variant forces a label to be chosen for it.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            DeviceTimingQuality::HostOnly => "host_only",
            DeviceTimingQuality::HostEnqueueWait => "host_enqueue_wait",
            DeviceTimingQuality::DeviceTimestamps => "device_timestamps",
            DeviceTimingQuality::HardwareCounters => "hardware_counters",
        }
    }
}
36
37#[derive(Clone, Copy, Debug, Eq, PartialEq)]
39pub struct DeviceProfile {
40 pub backend: &'static str,
42 pub supports_subgroup_ops: bool,
44 pub supports_indirect_dispatch: bool,
46 pub supports_distributed_collectives: bool,
48 pub supports_specialization_constants: bool,
50 pub supports_f16: bool,
52 pub supports_bf16: bool,
54 pub supports_trap_propagation: bool,
56 pub supports_tensor_cores: bool,
58 pub has_mul_high: bool,
60 pub has_dual_issue_fp32_int32: bool,
62 pub has_subgroup_shuffle: bool,
64 pub has_shared_memory: bool,
66 pub max_native_int_width: u32,
68 pub max_workgroup_size: [u32; 3],
70 pub max_invocations_per_workgroup: u32,
72 pub max_shared_memory_bytes: u32,
74 pub max_storage_buffer_binding_size: u64,
76 pub subgroup_size: u32,
78 pub compute_units: u32,
80 pub regs_per_thread_max: u32,
82 pub l1_cache_bytes: u32,
84 pub l2_cache_bytes: u32,
86 pub mem_bw_gbps: u32,
88 pub timing_quality: DeviceTimingQuality,
90 pub supports_device_timestamps: bool,
92 pub supports_hardware_counters: bool,
94 pub ideal_unroll_depth: u32,
96 pub ideal_vector_pack_bits: u32,
98 pub ideal_workgroup_tile: [u32; 3],
100 pub shared_memory_bank_count: u32,
102 pub shared_memory_bank_width_bytes: u32,
104}
105
106impl Default for DeviceProfile {
107 fn default() -> Self {
108 Self::conservative("unknown")
109 }
110}
111
112impl DeviceProfile {
113 #[must_use]
115 pub const fn conservative(backend: &'static str) -> Self {
116 Self {
117 backend,
118 supports_subgroup_ops: false,
119 supports_indirect_dispatch: false,
120 supports_distributed_collectives: false,
121 supports_specialization_constants: false,
122 supports_f16: false,
123 supports_bf16: false,
124 supports_trap_propagation: false,
125 supports_tensor_cores: false,
126 has_mul_high: false,
127 has_dual_issue_fp32_int32: false,
128 has_subgroup_shuffle: false,
129 has_shared_memory: false,
130 max_native_int_width: 32,
131 max_workgroup_size: [1, 1, 1],
132 max_invocations_per_workgroup: 1,
133 max_shared_memory_bytes: 0,
134 max_storage_buffer_binding_size: 0,
135 subgroup_size: 0,
136 compute_units: 0,
137 regs_per_thread_max: 0,
138 l1_cache_bytes: 0,
139 l2_cache_bytes: 0,
140 mem_bw_gbps: 0,
141 timing_quality: DeviceTimingQuality::HostOnly,
142 supports_device_timestamps: false,
143 supports_hardware_counters: false,
144 ideal_unroll_depth: 0,
145 ideal_vector_pack_bits: 0,
146 ideal_workgroup_tile: [0, 0, 0],
147 shared_memory_bank_count: 0,
148 shared_memory_bank_width_bytes: 0,
149 }
150 }
151
152 #[must_use]
154 pub fn from_backend(backend: &dyn crate::backend::VyreBackend) -> Self {
155 let max_workgroup_size = backend.max_workgroup_size();
156 Self {
157 backend: backend.id(),
158 supports_subgroup_ops: backend.supports_subgroup_ops(),
159 supports_indirect_dispatch: backend.supports_indirect_dispatch(),
160 supports_distributed_collectives: backend.supports_distributed_collectives(),
161 supports_specialization_constants: false,
162 supports_f16: backend.supports_f16(),
163 supports_bf16: backend.supports_bf16(),
164 supports_trap_propagation: false,
165 supports_tensor_cores: backend.supports_tensor_cores(),
166 has_mul_high: false,
167 has_dual_issue_fp32_int32: false,
168 has_subgroup_shuffle: backend.supports_subgroup_ops(),
169 has_shared_memory: false,
170 max_native_int_width: 32,
171 max_workgroup_size,
172 max_invocations_per_workgroup: backend.max_compute_invocations_per_workgroup(),
173 max_shared_memory_bytes: 0,
174 max_storage_buffer_binding_size: backend.max_storage_buffer_bytes(),
175 subgroup_size: backend.subgroup_size().unwrap_or(0),
176 compute_units: 0,
177 regs_per_thread_max: 0,
178 l1_cache_bytes: 0,
179 l2_cache_bytes: 0,
180 mem_bw_gbps: 0,
181 timing_quality: DeviceTimingQuality::HostOnly,
182 supports_device_timestamps: false,
183 supports_hardware_counters: false,
184 ideal_unroll_depth: 0,
185 ideal_vector_pack_bits: 0,
186 ideal_workgroup_tile: [0, 0, 0],
187 shared_memory_bank_count: 0,
188 shared_memory_bank_width_bytes: 0,
189 }
190 }
191
192 #[must_use]
194 pub const fn validation_capabilities(self) -> validate::BackendCapabilities {
195 validate::BackendCapabilities {
196 supports_subgroup_ops: self.supports_subgroup_ops,
197 supports_indirect_dispatch: self.supports_indirect_dispatch,
198 supports_specialization_constants: self.supports_specialization_constants,
199 has_mul_high: self.has_mul_high,
200 has_dual_issue_fp32_int32: self.has_dual_issue_fp32_int32,
201 has_tensor_core_int: self.supports_tensor_cores,
202 has_native_f16: self.supports_f16,
203 has_warp_shuffle: self.has_subgroup_shuffle,
204 has_shared_memory: self.has_shared_memory,
205 has_transcendental_polynomial_emit: true,
206 supports_distributed_collectives: self.supports_distributed_collectives,
207 max_native_int_width: self.max_native_int_width,
208 }
209 }
210
211 #[must_use]
213 pub const fn adapter_caps(self) -> AdapterCaps {
214 AdapterCaps {
215 backend: self.backend,
216 supports_subgroup_ops: self.supports_subgroup_ops,
217 supports_indirect_dispatch: self.supports_indirect_dispatch,
218 supports_specialization_constants: self.supports_specialization_constants,
219 max_workgroup_size: self.max_workgroup_size,
220 max_invocations_per_workgroup: self.max_invocations_per_workgroup,
221 max_shared_memory_bytes: self.max_shared_memory_bytes,
222 max_storage_buffer_binding_size: self.max_storage_buffer_binding_size,
223 subgroup_size: self.subgroup_size,
224 compute_units: self.compute_units,
225 regs_per_thread_max: self.regs_per_thread_max,
226 l1_cache_bytes: self.l1_cache_bytes,
227 l2_cache_bytes: self.l2_cache_bytes,
228 mem_bw_gbps: self.mem_bw_gbps,
229 ideal_unroll_depth: self.ideal_unroll_depth,
230 ideal_vector_pack_bits: self.ideal_vector_pack_bits,
231 ideal_workgroup_tile: self.ideal_workgroup_tile,
232 shared_memory_bank_count: self.shared_memory_bank_count,
233 shared_memory_bank_width_bytes: self.shared_memory_bank_width_bytes,
234 }
235 }
236
237 #[must_use]
239 pub const fn strategy_capabilities(self) -> validate::BackendCapabilities {
240 self.validation_capabilities()
241 }
242}
243
244impl From<DeviceProfile> for AdapterCaps {
245 #[inline]
246 fn from(profile: DeviceProfile) -> Self {
247 profile.adapter_caps()
248 }
249}
250
251impl From<DeviceProfile> for validate::BackendCapabilities {
252 #[inline]
253 fn from(profile: DeviceProfile) -> Self {
254 profile.validation_capabilities()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::{DeviceProfile, DeviceTimingQuality};

    /// Fixture with every projected feature enabled and non-zero sizing
    /// figures, so a dropped or mis-mapped field shows up in the projections.
    fn fully_featured_profile() -> DeviceProfile {
        DeviceProfile {
            backend: "test",
            supports_subgroup_ops: true,
            supports_indirect_dispatch: true,
            supports_distributed_collectives: true,
            supports_specialization_constants: true,
            supports_f16: true,
            supports_bf16: false,
            supports_trap_propagation: true,
            supports_tensor_cores: true,
            has_mul_high: true,
            has_dual_issue_fp32_int32: true,
            has_subgroup_shuffle: true,
            has_shared_memory: true,
            max_native_int_width: 64,
            max_workgroup_size: [256, 1, 1],
            max_invocations_per_workgroup: 256,
            max_shared_memory_bytes: 48 * 1024,
            max_storage_buffer_binding_size: 1 << 30,
            subgroup_size: 32,
            compute_units: 128,
            regs_per_thread_max: 255,
            l1_cache_bytes: 128 * 1024,
            l2_cache_bytes: 64 * 1024 * 1024,
            mem_bw_gbps: 1700,
            timing_quality: DeviceTimingQuality::HardwareCounters,
            supports_device_timestamps: true,
            supports_hardware_counters: true,
            ideal_unroll_depth: 8,
            ideal_vector_pack_bits: 128,
            ideal_workgroup_tile: [16, 16, 1],
            shared_memory_bank_count: 32,
            shared_memory_bank_width_bytes: 4,
        }
    }

    /// The report labels are compared verbatim, so any rename fails here.
    #[test]
    fn timing_quality_has_stable_report_strings() {
        let expected = [
            (DeviceTimingQuality::HostOnly, "host_only"),
            (DeviceTimingQuality::HostEnqueueWait, "host_enqueue_wait"),
            (DeviceTimingQuality::DeviceTimestamps, "device_timestamps"),
            (DeviceTimingQuality::HardwareCounters, "hardware_counters"),
        ];

        for &(quality, label) in &expected {
            assert_eq!(quality.as_str(), label);
        }
    }

    /// All three projections must reflect the same underlying profile.
    #[test]
    fn projections_share_the_same_feature_bits() {
        let profile = fully_featured_profile();

        let validation = profile.validation_capabilities();
        let adapter = profile.adapter_caps();
        let strategy = profile.strategy_capabilities();

        // Validation view.
        assert!(validation.supports_subgroup_ops);
        assert!(validation.supports_distributed_collectives);

        // Optimizer view.
        assert!(adapter.supports_subgroup_ops);
        assert_eq!(adapter.max_invocations_per_workgroup, 256);
        assert_eq!(adapter.ideal_unroll_depth, 8);
        assert_eq!(adapter.ideal_vector_pack_bits, 128);
        assert_eq!(adapter.ideal_workgroup_tile, [16, 16, 1]);

        // Strategy view.
        assert!(strategy.has_warp_shuffle);
        assert_eq!(strategy.max_native_int_width, 64);

        // Timing fields live on the profile itself.
        assert_eq!(
            profile.timing_quality,
            DeviceTimingQuality::HardwareCounters
        );
        assert!(profile.supports_device_timestamps);
        assert!(profile.supports_hardware_counters);
    }
}