Skip to main content

vyre_driver/
validation.rs

1//! Shared validation caches and launch-geometry checks for concrete drivers.
2
3use std::collections::HashSet;
4use std::hash::BuildHasherDefault;
5
6use rustc_hash::FxHasher;
7use vyre_foundation::ir::{OpId, Program};
8use vyre_foundation::validate::{BackendValidationCapabilities, ValidationOptions};
9
10use crate::{BackendError, DispatchConfig, VyreBackend};
11
/// Default cap on successful-validation hashes kept by one backend instance.
pub const DEFAULT_VALIDATION_HASH_ENTRIES: usize = 8_192;
/// Default cap on VSA fingerprint hashes kept by one backend instance.
pub const DEFAULT_VALIDATION_VSA_ENTRIES: usize = 2_048;
/// Default VSA shard count requested when building a validation cache.
pub const DEFAULT_VALIDATION_VSA_SHARDS: usize = 64;
18
19type ValidationSet = dashmap::DashSet<blake3::Hash, BuildHasherDefault<FxHasher>>;
20
21/// Successful-program validation cache shared by concrete drivers.
22pub struct ValidationCache {
23    hashes: ValidationSet,
24    vsa_hashes: ValidationSet,
25    max_hash_entries: usize,
26    max_vsa_entries: usize,
27    vsa_shards: usize,
28}
29
30impl std::fmt::Debug for ValidationCache {
31    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        formatter
33            .debug_struct("ValidationCache")
34            .field("hashes", &self.hashes.len())
35            .field("vsa_hashes", &self.vsa_hashes.len())
36            .field("vsa_shards", &self.vsa_shards)
37            .field("max_hash_entries", &self.max_hash_entries)
38            .field("max_vsa_entries", &self.max_vsa_entries)
39            .finish()
40    }
41}
42
43impl Default for ValidationCache {
44    fn default() -> Self {
45        Self::new(
46            DEFAULT_VALIDATION_HASH_ENTRIES,
47            DEFAULT_VALIDATION_VSA_ENTRIES,
48            DEFAULT_VALIDATION_VSA_SHARDS,
49        )
50    }
51}
52
53impl ValidationCache {
54    /// Create a validation cache with bounded hash and VSA storage.
55    #[must_use]
56    pub fn new(max_hash_entries: usize, max_vsa_entries: usize, vsa_shards: usize) -> Self {
57        let shard_count = vsa_shards.max(1);
58        Self {
59            hashes: dashmap::DashSet::with_hasher(BuildHasherDefault::<FxHasher>::default()),
60            vsa_hashes: dashmap::DashSet::with_capacity_and_hasher(
61                max_vsa_entries.max(1),
62                BuildHasherDefault::<FxHasher>::default(),
63            ),
64            max_hash_entries: max_hash_entries.max(1),
65            max_vsa_entries: max_vsa_entries.max(1),
66            vsa_shards: shard_count,
67        }
68    }
69
70    /// Compute the validation hash for a program.
71    #[must_use]
72    pub fn program_hash(program: &Program) -> blake3::Hash {
73        blake3::Hash::from(program.fingerprint())
74    }
75
76    /// Return whether a validation hash is cached.
77    #[must_use]
78    pub fn contains_hash(&self, hash: &blake3::Hash) -> bool {
79        self.hashes.contains(hash)
80    }
81
82    /// Remember a successful validation hash.
83    pub fn remember_hash(&self, hash: blake3::Hash) {
84        if self.hashes.len() >= self.max_hash_entries {
85            self.hashes.clear();
86        }
87        self.hashes.insert(hash);
88    }
89
90    /// Remember a successful validation hash and its VSA fingerprint.
91    ///
92    /// # Errors
93    ///
94    /// Returns if a VSA shard lock is poisoned.
95    pub fn remember_success(&self, hash: blake3::Hash, vsa: &[u32]) -> Result<(), BackendError> {
96        self.remember_hash(hash);
97        if self.vsa_hashes.len() >= self.max_vsa_entries {
98            self.vsa_hashes.clear();
99        }
100        self.vsa_hashes.insert(vsa_words_hash(vsa));
101        Ok(())
102    }
103
104    /// Clear cached validation state.
105    ///
106    /// # Errors
107    ///
108    /// Returns if a VSA shard lock is poisoned.
109    pub fn clear(&self) -> Result<(), BackendError> {
110        self.hashes.clear();
111        self.vsa_hashes.clear();
112        Ok(())
113    }
114
115    /// Validate `program` once, memoizing the complete backend contract.
116    ///
117    /// This is the shared driver validation path: foundation invariants,
118    /// backend supported-op coverage, program capability requirements, and
119    /// VSA cache insertion all happen in one place. Concrete drivers supply
120    /// only their actual capability values.
121    ///
122    /// # Errors
123    ///
124    /// Returns when validation fails or a VSA shard lock is poisoned.
125    pub fn get_or_validate(
126        &self,
127        program: &Program,
128        validation_options: ValidationOptions<'_>,
129        supported_ops: &HashSet<OpId>,
130        caps: ProgramValidationCaps,
131    ) -> Result<(), BackendError> {
132        let hash = Self::program_hash(program);
133        if self.contains_hash(&hash) || program.is_validated_on(caps.backend_id) {
134            self.remember_hash(hash);
135            return Ok(());
136        }
137
138        validate_program_contract(program, validation_options, supported_ops, caps)?;
139
140        let vsa = crate::launch::program_vsa_fingerprint_words(program);
141        self.remember_success(hash, &vsa)?;
142        program.mark_validated_on(caps.backend_id);
143        Ok(())
144    }
145
146    /// Validate `program` against a concrete backend and cache successful
147    /// results.
148    ///
149    /// This is the canonical driver-owned validation-cache entry point for
150    /// backends that implement both the runtime backend contract and the
151    /// foundation capability-validation contract.
152    ///
153    /// # Errors
154    ///
155    /// Returns when validation fails or cache mutation fails.
156    pub fn get_or_validate_backend<B>(
157        &self,
158        program: &Program,
159        backend: &B,
160    ) -> Result<(), BackendError>
161    where
162        B: VyreBackend + BackendValidationCapabilities,
163    {
164        let validation_options = ValidationOptions::default().with_backend(backend);
165        self.get_or_validate(
166            program,
167            validation_options,
168            backend.supported_ops(),
169            ProgramValidationCaps::from_backend(backend),
170        )
171    }
172}
173
/// Snapshot of the concrete capability values a backend reports, consumed by
/// the shared program-validation path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ProgramValidationCaps {
    /// Stable backend identifier, used in diagnostics and as the key for
    /// per-backend validation stamps.
    pub backend_id: &'static str,
    /// Whether native subgroup operations are available and lowered.
    pub supports_subgroup_ops: bool,
    /// Whether IEEE binary16 buffers and operations are lowered.
    pub supports_f16: bool,
    /// Whether bfloat16 buffers and operations are lowered.
    pub supports_bf16: bool,
    /// Whether indirect dispatch is lowered.
    pub supports_indirect_dispatch: bool,
    /// Whether distributed collective-communication nodes are lowered. When
    /// false, shared validation first lowers collectives for a single rank.
    pub supports_distributed_collectives: bool,
    /// Whether `Node::Trap` is lowered with backend-visible trap semantics.
    pub supports_trap_propagation: bool,
    /// Maximum supported workgroup dimensions, one entry per axis.
    pub max_workgroup_size: [u32; 3],
}
194
195impl ProgramValidationCaps {
196    /// Snapshot capability values from a `VyreBackend` trait object.
197    #[must_use]
198    pub fn from_backend(backend: &dyn VyreBackend) -> Self {
199        Self {
200            backend_id: backend.id(),
201            supports_subgroup_ops: backend.supports_subgroup_ops(),
202            supports_f16: backend.supports_f16(),
203            supports_bf16: backend.supports_bf16(),
204            supports_indirect_dispatch: backend.supports_indirect_dispatch(),
205            supports_distributed_collectives: backend.supports_distributed_collectives(),
206            supports_trap_propagation: true,
207            max_workgroup_size: backend.max_workgroup_size(),
208        }
209    }
210}
211
212/// Validate a program against backend-neutral and backend-reported contracts.
213///
214/// # Errors
215///
216/// Returns when foundation validation, supported-op validation, or required
217/// capability checks fail.
218pub fn validate_program_contract(
219    program: &Program,
220    validation_options: ValidationOptions<'_>,
221    supported_ops: &HashSet<OpId>,
222    caps: ProgramValidationCaps,
223) -> Result<(), BackendError> {
224    let lowered_program = if caps.supports_distributed_collectives {
225        None
226    } else {
227        vyre_foundation::transform::collectives::lower_single_rank_collectives(program).map_err(
228            |error| BackendError::InvalidProgram {
229                fix: error.to_string(),
230            },
231        )?
232    };
233    let program = lowered_program.as_ref().unwrap_or(program);
234    let report = vyre_foundation::validate::validate_with_options(program, validation_options);
235    if let Some(first) = report.errors.into_iter().next() {
236        return Err(BackendError::InvalidProgram {
237            fix: first.message.into_owned(),
238        });
239    }
240
241    validate_supported_ops(program, caps.backend_id, supported_ops).map_err(|error| {
242        BackendError::InvalidProgram {
243            fix: error.to_string(),
244        }
245    })?;
246
247    let required = vyre_foundation::program_caps::scan(program);
248    vyre_foundation::program_caps::check_backend_capabilities(
249        caps.backend_id,
250        caps.supports_subgroup_ops,
251        caps.supports_f16,
252        caps.supports_bf16,
253        caps.supports_indirect_dispatch,
254        caps.supports_trap_propagation,
255        caps.supports_distributed_collectives,
256        caps.max_workgroup_size,
257        &required,
258    )
259    .map_err(|error| BackendError::InvalidProgram {
260        fix: error.to_string(),
261    })
262}
263
264fn validate_supported_ops(
265    program: &Program,
266    backend_id: &'static str,
267    supported_ops: &HashSet<OpId>,
268) -> Result<(), vyre_foundation::ir::ValidationError> {
269    struct SupportedOpsBackend<'a> {
270        id: &'static str,
271        ops: &'a HashSet<OpId>,
272    }
273
274    impl crate::backend::Backend for SupportedOpsBackend<'_> {
275        fn id(&self) -> &'static str {
276            self.id
277        }
278
279        fn version(&self) -> &'static str {
280            env!("CARGO_PKG_VERSION")
281        }
282
283        fn supported_ops(&self) -> &HashSet<OpId> {
284            self.ops
285        }
286    }
287
288    crate::backend::validation::validate_program(
289        program,
290        &SupportedOpsBackend {
291            id: backend_id,
292            ops: supported_ops,
293        },
294    )
295}
296
/// Launch-geometry ceilings a concrete driver reports for its device.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LaunchGeometryLimits {
    /// Backend name interpolated into diagnostics.
    pub backend: &'static str,
    /// Cap on the total invocations in a single workgroup or block.
    pub max_threads_per_block: u32,
    /// Per-axis cap on workgroup or block dimensions, ordered x, y, z.
    pub max_block_dim: [u32; 3],
    /// Per-axis cap on the number of workgroups in a grid.
    pub max_grid_dim: [u32; 3],
}
309
310/// Validate workgroup and grid dimensions against backend launch limits.
311///
312/// # Errors
313///
314/// Returns when dimensions are zero, overflow the invocation product, exceed
315/// workgroup limits, or exceed per-axis grid limits.
316pub fn validate_launch_geometry(
317    workgroup: [u32; 3],
318    grid: [u32; 3],
319    limits: LaunchGeometryLimits,
320) -> Result<(), BackendError> {
321    if workgroup.contains(&0) || grid.contains(&0) {
322        return Err(BackendError::InvalidProgram {
323            fix: format!(
324                "Fix: {} workgroup and grid dimensions must all be non-zero.",
325                limits.backend
326            ),
327        });
328    }
329    let threads = workgroup[0]
330        .checked_mul(workgroup[1])
331        .and_then(|xy| xy.checked_mul(workgroup[2]))
332        .ok_or_else(|| BackendError::InvalidProgram {
333            fix: format!(
334                "Fix: {} workgroup dimensions overflowed u32; reduce workgroup_override.",
335                limits.backend
336            ),
337        })?;
338    if threads > limits.max_threads_per_block {
339        return Err(BackendError::InvalidProgram {
340            fix: format!(
341                "Fix: {} workgroup has {threads} threads but device max is {}.",
342                limits.backend, limits.max_threads_per_block
343            ),
344        });
345    }
346    for (axis, &dim) in workgroup.iter().enumerate() {
347        if dim > limits.max_block_dim[axis] {
348            return Err(BackendError::InvalidProgram {
349                fix: format!(
350                    "Fix: {} workgroup axis {axis} requested {} threads but device max is {}.",
351                    limits.backend, dim, limits.max_block_dim[axis]
352                ),
353            });
354        }
355    }
356    for (axis, &dim) in grid.iter().enumerate() {
357        if dim > limits.max_grid_dim[axis] {
358            return Err(BackendError::InvalidProgram {
359                fix: format!(
360                    "Fix: {} grid axis {axis} requested {} workgroups but device max is {}.",
361                    limits.backend, dim, limits.max_grid_dim[axis]
362                ),
363            });
364        }
365    }
366    Ok(())
367}
368
369/// Validate a program's effective workgroup shape against a backend's reported limits.
370///
371/// This is the shared pre-dispatch gate for callers that have a `VyreBackend`
372/// trait object but have not entered a concrete driver yet.
373///
374/// # Errors
375///
376/// Returns when any workgroup axis is zero, exceeds the backend's per-axis
377/// limit, or when total invocations exceed the backend's workgroup limit.
378pub fn validate_program_for_backend(
379    backend: &dyn VyreBackend,
380    program: &Program,
381    config: &DispatchConfig,
382) -> Result<(), BackendError> {
383    let workgroup = config
384        .workgroup_override
385        .unwrap_or(program.workgroup_size());
386    let max_axes = backend.max_workgroup_size();
387    if workgroup.contains(&0) {
388        return Err(BackendError::InvalidProgram {
389            fix: format!(
390                "Fix: backend `{}` cannot dispatch zero-sized workgroup dimensions; set positive workgroup sizes.",
391                backend.id()
392            ),
393        });
394    }
395    for (axis, &dim) in workgroup.iter().enumerate() {
396        if dim > max_axes[axis] {
397            return Err(BackendError::InvalidProgram {
398                fix: format!(
399                    "Fix: backend `{}` workgroup axis {axis} requested {} but max is {}.",
400                    backend.id(),
401                    dim,
402                    max_axes[axis]
403                ),
404            });
405        }
406    }
407    let invocations = workgroup[0]
408        .checked_mul(workgroup[1])
409        .and_then(|xy| xy.checked_mul(workgroup[2]))
410        .ok_or_else(|| BackendError::InvalidProgram {
411            fix: format!(
412                "Fix: backend `{}` workgroup dimensions overflowed u32; reduce workgroup size.",
413                backend.id()
414            ),
415        })?;
416    let max_invocations = backend.max_compute_invocations_per_workgroup();
417    if invocations > max_invocations {
418        return Err(BackendError::InvalidProgram {
419            fix: format!(
420                "Fix: backend `{}` workgroup has {invocations} invocations but max is {max_invocations}.",
421                backend.id()
422            ),
423        });
424    }
425    if let Some(grid) = config.grid_override {
426        let max_workgroups = backend.max_compute_workgroups_per_dimension();
427        if grid.contains(&0) {
428            return Err(BackendError::InvalidProgram {
429                fix: format!(
430                    "Fix: backend `{}` cannot dispatch zero-sized grid dimensions; set positive grid_override values.",
431                    backend.id()
432                ),
433            });
434        }
435        for (axis, &dim) in grid.iter().enumerate() {
436            if dim > max_workgroups {
437                return Err(BackendError::InvalidProgram {
438                    fix: format!(
439                        "Fix: backend `{}` grid_override axis {axis} requested {} workgroups but max is {}.",
440                        backend.id(),
441                        dim,
442                        max_workgroups
443                    ),
444                });
445            }
446        }
447    }
448    Ok(())
449}
450
451fn vsa_words_hash(words: &[u32]) -> blake3::Hash {
452    let mut hasher = blake3::Hasher::new();
453    hasher.update(&(words.len() as u64).to_le_bytes());
454    for word in words {
455        hasher.update(&word.to_le_bytes());
456    }
457    hasher.finalize()
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validation_cache_records_vsa_without_lock_shards() {
        let cache = ValidationCache::new(8, 8, 4);
        let program_hash = blake3::hash(b"program");
        let outcome = cache.remember_success(program_hash, &[1, 2, 3, 4]);
        outcome.expect("Fix: lock-free VSA cache insertion must not fail");

        assert!(cache.contains_hash(&program_hash));
        assert_eq!(cache.vsa_hashes.len(), 1);
        let rendered = format!("{cache:?}");
        assert!(rendered.contains("vsa_hashes"));
    }

    #[test]
    fn validation_cache_bounds_vsa_hashes_by_clear() {
        let cache = ValidationCache::new(8, 2, 4);
        (0..3u32).for_each(|value| {
            let digest = blake3::hash(&value.to_le_bytes());
            cache
                .remember_success(digest, &[value])
                .expect("Fix: VSA cache insertion must stay infallible");
        });
        let retained = cache.vsa_hashes.len();
        assert!(
            retained <= 2,
            "Fix: bounded VSA cache must not grow past max entries"
        );
    }
}