1use std::collections::HashSet;
4use std::hash::BuildHasherDefault;
5
6use rustc_hash::FxHasher;
7use vyre_foundation::ir::{OpId, Program};
8use vyre_foundation::validate::{BackendValidationCapabilities, ValidationOptions};
9
10use crate::{BackendError, DispatchConfig, VyreBackend};
11
/// Program-hash capacity that `ValidationCache::default()` passes to
/// `ValidationCache::new` as `max_hash_entries`.
pub const DEFAULT_VALIDATION_HASH_ENTRIES: usize = 8192;
/// VSA-hash capacity that `ValidationCache::default()` passes to
/// `ValidationCache::new` as `max_vsa_entries`.
pub const DEFAULT_VALIDATION_VSA_ENTRIES: usize = 2048;
/// Shard count that `ValidationCache::default()` passes to
/// `ValidationCache::new` as `vsa_shards`.
pub const DEFAULT_VALIDATION_VSA_SHARDS: usize = 64;
18
/// Concurrent set of BLAKE3 digests, bucketed with `FxHasher`.
type ValidationSet = dashmap::DashSet<blake3::Hash, BuildHasherDefault<FxHasher>>;
20
/// Bounded, shareable record of programs that already passed validation.
///
/// All mutation goes through `&self`; both sets are `DashSet`s.
pub struct ValidationCache {
    /// Program hashes (see `ValidationCache::program_hash`) recorded as validated.
    hashes: ValidationSet,
    /// Digests of VSA fingerprint words recorded by `remember_success`.
    vsa_hashes: ValidationSet,
    /// Entry count at which `hashes` is cleared before the next insert (at least 1).
    max_hash_entries: usize,
    /// Entry count at which `vsa_hashes` is cleared before the next insert (at least 1).
    max_vsa_entries: usize,
    /// Shard count requested at construction (at least 1); in this file it is
    /// only read back by the `Debug` impl.
    vsa_shards: usize,
}
29
30impl std::fmt::Debug for ValidationCache {
31 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 formatter
33 .debug_struct("ValidationCache")
34 .field("hashes", &self.hashes.len())
35 .field("vsa_hashes", &self.vsa_hashes.len())
36 .field("vsa_shards", &self.vsa_shards)
37 .field("max_hash_entries", &self.max_hash_entries)
38 .field("max_vsa_entries", &self.max_vsa_entries)
39 .finish()
40 }
41}
42
43impl Default for ValidationCache {
44 fn default() -> Self {
45 Self::new(
46 DEFAULT_VALIDATION_HASH_ENTRIES,
47 DEFAULT_VALIDATION_VSA_ENTRIES,
48 DEFAULT_VALIDATION_VSA_SHARDS,
49 )
50 }
51}
52
53impl ValidationCache {
54 #[must_use]
56 pub fn new(max_hash_entries: usize, max_vsa_entries: usize, vsa_shards: usize) -> Self {
57 let shard_count = vsa_shards.max(1);
58 Self {
59 hashes: dashmap::DashSet::with_hasher(BuildHasherDefault::<FxHasher>::default()),
60 vsa_hashes: dashmap::DashSet::with_capacity_and_hasher(
61 max_vsa_entries.max(1),
62 BuildHasherDefault::<FxHasher>::default(),
63 ),
64 max_hash_entries: max_hash_entries.max(1),
65 max_vsa_entries: max_vsa_entries.max(1),
66 vsa_shards: shard_count,
67 }
68 }
69
70 #[must_use]
72 pub fn program_hash(program: &Program) -> blake3::Hash {
73 blake3::Hash::from(program.fingerprint())
74 }
75
76 #[must_use]
78 pub fn contains_hash(&self, hash: &blake3::Hash) -> bool {
79 self.hashes.contains(hash)
80 }
81
82 pub fn remember_hash(&self, hash: blake3::Hash) {
84 if self.hashes.len() >= self.max_hash_entries {
85 self.hashes.clear();
86 }
87 self.hashes.insert(hash);
88 }
89
90 pub fn remember_success(&self, hash: blake3::Hash, vsa: &[u32]) -> Result<(), BackendError> {
96 self.remember_hash(hash);
97 if self.vsa_hashes.len() >= self.max_vsa_entries {
98 self.vsa_hashes.clear();
99 }
100 self.vsa_hashes.insert(vsa_words_hash(vsa));
101 Ok(())
102 }
103
104 pub fn clear(&self) -> Result<(), BackendError> {
110 self.hashes.clear();
111 self.vsa_hashes.clear();
112 Ok(())
113 }
114
115 pub fn get_or_validate(
126 &self,
127 program: &Program,
128 validation_options: ValidationOptions<'_>,
129 supported_ops: &HashSet<OpId>,
130 caps: ProgramValidationCaps,
131 ) -> Result<(), BackendError> {
132 let hash = Self::program_hash(program);
133 if self.contains_hash(&hash) || program.is_validated_on(caps.backend_id) {
134 self.remember_hash(hash);
135 return Ok(());
136 }
137
138 validate_program_contract(program, validation_options, supported_ops, caps)?;
139
140 let vsa = crate::launch::program_vsa_fingerprint_words(program);
141 self.remember_success(hash, &vsa)?;
142 program.mark_validated_on(caps.backend_id);
143 Ok(())
144 }
145
146 pub fn get_or_validate_backend<B>(
157 &self,
158 program: &Program,
159 backend: &B,
160 ) -> Result<(), BackendError>
161 where
162 B: VyreBackend + BackendValidationCapabilities,
163 {
164 let validation_options = ValidationOptions::default().with_backend(backend);
165 self.get_or_validate(
166 program,
167 validation_options,
168 backend.supported_ops(),
169 ProgramValidationCaps::from_backend(backend),
170 )
171 }
172}
173
/// Backend capability snapshot consumed by `validate_program_contract`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProgramValidationCaps {
    /// Backend identifier; used for the program's per-backend validation mark
    /// and passed to the supported-op and capability checks.
    pub backend_id: &'static str,
    /// Whether the backend reports support for subgroup operations.
    pub supports_subgroup_ops: bool,
    /// Whether the backend reports support for `f16`.
    pub supports_f16: bool,
    /// Whether the backend reports support for `bf16`.
    pub supports_bf16: bool,
    /// Whether the backend reports support for indirect dispatch.
    pub supports_indirect_dispatch: bool,
    /// Whether the backend reports support for distributed collectives; when
    /// false, `validate_program_contract` first runs the single-rank
    /// collectives lowering.
    pub supports_distributed_collectives: bool,
    /// Trap-propagation support flag; `from_backend` always sets it to `true`.
    pub supports_trap_propagation: bool,
    /// Workgroup size limit as reported by `VyreBackend::max_workgroup_size`.
    pub max_workgroup_size: [u32; 3],
}
194
195impl ProgramValidationCaps {
196 #[must_use]
198 pub fn from_backend(backend: &dyn VyreBackend) -> Self {
199 Self {
200 backend_id: backend.id(),
201 supports_subgroup_ops: backend.supports_subgroup_ops(),
202 supports_f16: backend.supports_f16(),
203 supports_bf16: backend.supports_bf16(),
204 supports_indirect_dispatch: backend.supports_indirect_dispatch(),
205 supports_distributed_collectives: backend.supports_distributed_collectives(),
206 supports_trap_propagation: true,
207 max_workgroup_size: backend.max_workgroup_size(),
208 }
209 }
210}
211
212pub fn validate_program_contract(
219 program: &Program,
220 validation_options: ValidationOptions<'_>,
221 supported_ops: &HashSet<OpId>,
222 caps: ProgramValidationCaps,
223) -> Result<(), BackendError> {
224 let lowered_program = if caps.supports_distributed_collectives {
225 None
226 } else {
227 vyre_foundation::transform::collectives::lower_single_rank_collectives(program).map_err(
228 |error| BackendError::InvalidProgram {
229 fix: error.to_string(),
230 },
231 )?
232 };
233 let program = lowered_program.as_ref().unwrap_or(program);
234 let report = vyre_foundation::validate::validate_with_options(program, validation_options);
235 if let Some(first) = report.errors.into_iter().next() {
236 return Err(BackendError::InvalidProgram {
237 fix: first.message.into_owned(),
238 });
239 }
240
241 validate_supported_ops(program, caps.backend_id, supported_ops).map_err(|error| {
242 BackendError::InvalidProgram {
243 fix: error.to_string(),
244 }
245 })?;
246
247 let required = vyre_foundation::program_caps::scan(program);
248 vyre_foundation::program_caps::check_backend_capabilities(
249 caps.backend_id,
250 caps.supports_subgroup_ops,
251 caps.supports_f16,
252 caps.supports_bf16,
253 caps.supports_indirect_dispatch,
254 caps.supports_trap_propagation,
255 caps.supports_distributed_collectives,
256 caps.max_workgroup_size,
257 &required,
258 )
259 .map_err(|error| BackendError::InvalidProgram {
260 fix: error.to_string(),
261 })
262}
263
264fn validate_supported_ops(
265 program: &Program,
266 backend_id: &'static str,
267 supported_ops: &HashSet<OpId>,
268) -> Result<(), vyre_foundation::ir::ValidationError> {
269 struct SupportedOpsBackend<'a> {
270 id: &'static str,
271 ops: &'a HashSet<OpId>,
272 }
273
274 impl crate::backend::Backend for SupportedOpsBackend<'_> {
275 fn id(&self) -> &'static str {
276 self.id
277 }
278
279 fn version(&self) -> &'static str {
280 env!("CARGO_PKG_VERSION")
281 }
282
283 fn supported_ops(&self) -> &HashSet<OpId> {
284 self.ops
285 }
286 }
287
288 crate::backend::validation::validate_program(
289 program,
290 &SupportedOpsBackend {
291 id: backend_id,
292 ops: supported_ops,
293 },
294 )
295}
296
/// Device limits that `validate_launch_geometry` checks a launch against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LaunchGeometryLimits {
    /// Backend name interpolated into error messages.
    pub backend: &'static str,
    /// Upper bound on the product of the three workgroup dimensions.
    pub max_threads_per_block: u32,
    /// Per-axis upper bound on the workgroup dimensions.
    pub max_block_dim: [u32; 3],
    /// Per-axis upper bound on the grid dimensions.
    pub max_grid_dim: [u32; 3],
}
309
310pub fn validate_launch_geometry(
317 workgroup: [u32; 3],
318 grid: [u32; 3],
319 limits: LaunchGeometryLimits,
320) -> Result<(), BackendError> {
321 if workgroup.contains(&0) || grid.contains(&0) {
322 return Err(BackendError::InvalidProgram {
323 fix: format!(
324 "Fix: {} workgroup and grid dimensions must all be non-zero.",
325 limits.backend
326 ),
327 });
328 }
329 let threads = workgroup[0]
330 .checked_mul(workgroup[1])
331 .and_then(|xy| xy.checked_mul(workgroup[2]))
332 .ok_or_else(|| BackendError::InvalidProgram {
333 fix: format!(
334 "Fix: {} workgroup dimensions overflowed u32; reduce workgroup_override.",
335 limits.backend
336 ),
337 })?;
338 if threads > limits.max_threads_per_block {
339 return Err(BackendError::InvalidProgram {
340 fix: format!(
341 "Fix: {} workgroup has {threads} threads but device max is {}.",
342 limits.backend, limits.max_threads_per_block
343 ),
344 });
345 }
346 for (axis, &dim) in workgroup.iter().enumerate() {
347 if dim > limits.max_block_dim[axis] {
348 return Err(BackendError::InvalidProgram {
349 fix: format!(
350 "Fix: {} workgroup axis {axis} requested {} threads but device max is {}.",
351 limits.backend, dim, limits.max_block_dim[axis]
352 ),
353 });
354 }
355 }
356 for (axis, &dim) in grid.iter().enumerate() {
357 if dim > limits.max_grid_dim[axis] {
358 return Err(BackendError::InvalidProgram {
359 fix: format!(
360 "Fix: {} grid axis {axis} requested {} workgroups but device max is {}.",
361 limits.backend, dim, limits.max_grid_dim[axis]
362 ),
363 });
364 }
365 }
366 Ok(())
367}
368
369pub fn validate_program_for_backend(
379 backend: &dyn VyreBackend,
380 program: &Program,
381 config: &DispatchConfig,
382) -> Result<(), BackendError> {
383 let workgroup = config
384 .workgroup_override
385 .unwrap_or(program.workgroup_size());
386 let max_axes = backend.max_workgroup_size();
387 if workgroup.contains(&0) {
388 return Err(BackendError::InvalidProgram {
389 fix: format!(
390 "Fix: backend `{}` cannot dispatch zero-sized workgroup dimensions; set positive workgroup sizes.",
391 backend.id()
392 ),
393 });
394 }
395 for (axis, &dim) in workgroup.iter().enumerate() {
396 if dim > max_axes[axis] {
397 return Err(BackendError::InvalidProgram {
398 fix: format!(
399 "Fix: backend `{}` workgroup axis {axis} requested {} but max is {}.",
400 backend.id(),
401 dim,
402 max_axes[axis]
403 ),
404 });
405 }
406 }
407 let invocations = workgroup[0]
408 .checked_mul(workgroup[1])
409 .and_then(|xy| xy.checked_mul(workgroup[2]))
410 .ok_or_else(|| BackendError::InvalidProgram {
411 fix: format!(
412 "Fix: backend `{}` workgroup dimensions overflowed u32; reduce workgroup size.",
413 backend.id()
414 ),
415 })?;
416 let max_invocations = backend.max_compute_invocations_per_workgroup();
417 if invocations > max_invocations {
418 return Err(BackendError::InvalidProgram {
419 fix: format!(
420 "Fix: backend `{}` workgroup has {invocations} invocations but max is {max_invocations}.",
421 backend.id()
422 ),
423 });
424 }
425 if let Some(grid) = config.grid_override {
426 let max_workgroups = backend.max_compute_workgroups_per_dimension();
427 if grid.contains(&0) {
428 return Err(BackendError::InvalidProgram {
429 fix: format!(
430 "Fix: backend `{}` cannot dispatch zero-sized grid dimensions; set positive grid_override values.",
431 backend.id()
432 ),
433 });
434 }
435 for (axis, &dim) in grid.iter().enumerate() {
436 if dim > max_workgroups {
437 return Err(BackendError::InvalidProgram {
438 fix: format!(
439 "Fix: backend `{}` grid_override axis {axis} requested {} workgroups but max is {}.",
440 backend.id(),
441 dim,
442 max_workgroups
443 ),
444 });
445 }
446 }
447 }
448 Ok(())
449}
450
451fn vsa_words_hash(words: &[u32]) -> blake3::Hash {
452 let mut hasher = blake3::Hasher::new();
453 hasher.update(&(words.len() as u64).to_le_bytes());
454 for word in words {
455 hasher.update(&word.to_le_bytes());
456 }
457 hasher.finalize()
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// One successful validation records both the program hash and one VSA
    /// hash, and the `Debug` output mentions the VSA set.
    #[test]
    fn validation_cache_records_vsa_without_lock_shards() {
        let program_hash = blake3::hash(b"program");
        let vsa_words = [1u32, 2, 3, 4];
        let cache = ValidationCache::new(8, 8, 4);

        cache
            .remember_success(program_hash, &vsa_words)
            .expect("Fix: lock-free VSA cache insertion must not fail");

        assert!(cache.contains_hash(&program_hash));
        assert_eq!(cache.vsa_hashes.len(), 1);
        let rendered = format!("{cache:?}");
        assert!(rendered.contains("vsa_hashes"));
    }

    /// Inserting more distinct VSA fingerprints than the limit never leaves
    /// the VSA set larger than that limit.
    #[test]
    fn validation_cache_bounds_vsa_hashes_by_clear() {
        let cache = ValidationCache::new(8, 2, 4);
        for word in 0..3u32 {
            let program_hash = blake3::hash(&word.to_le_bytes());
            cache
                .remember_success(program_hash, &[word])
                .expect("Fix: VSA cache insertion must stay infallible");
        }

        let retained = cache.vsa_hashes.len();
        assert!(
            retained <= 2,
            "Fix: bounded VSA cache must not grow past max entries"
        );
    }
}