1use std::collections::BTreeMap;
4
5use vyre_foundation::ir::Program;
6use vyre_spec::data_type::DataType;
7
8#[derive(Debug, Clone, PartialEq)]
16#[non_exhaustive]
17pub enum SpecValue {
18 U32(u32),
20 I32(i32),
22 F32(f32),
24 Bool(bool),
26 DType(DataType),
32}
33
34impl SpecValue {
35 #[must_use]
38 pub fn as_pipeline_f64(&self) -> f64 {
39 match self {
40 SpecValue::U32(value) => f64::from(*value),
41 SpecValue::I32(value) => f64::from(*value),
42 SpecValue::F32(value) => f64::from(*value),
43 SpecValue::Bool(value) => f64::from(u8::from(*value)),
44 SpecValue::DType(dtype) => f64::from(dtype_tag(dtype)),
45 }
46 }
47
48 #[must_use]
50 pub fn cache_hash(&self) -> u64 {
51 match self {
52 SpecValue::U32(value) => u64::from(*value) << 8,
53 SpecValue::I32(value) => (1u64) | ((*value as u32 as u64) << 8),
54 SpecValue::F32(value) => (2u64) | ((value.to_bits() as u64) << 8),
55 SpecValue::Bool(value) => (3u64) | (u64::from(u8::from(*value)) << 8),
56 SpecValue::DType(dtype) => (4u64) | (u64::from(dtype_tag(dtype)) << 8),
57 }
58 }
59}
60
61fn dtype_tag(dtype: &DataType) -> u32 {
74 match dtype {
75 DataType::U32 => 0x01,
76 DataType::I32 => 0x02,
77 DataType::U64 => 0x03,
78 DataType::Vec2U32 => 0x04,
79 DataType::Vec4U32 => 0x05,
80 DataType::Bool => 0x06,
81 DataType::Bytes => 0x07,
82 DataType::Array { .. } => 0x08,
83 DataType::F16 => 0x09,
84 DataType::BF16 => 0x0A,
85 DataType::F32 => 0x0B,
86 DataType::F64 => 0x0C,
87 DataType::Tensor => 0x0D,
88 DataType::U8 => 0x0E,
89 DataType::U16 => 0x0F,
90 DataType::I8 => 0x10,
91 DataType::I16 => 0x11,
92 DataType::I64 => 0x12,
93 DataType::Handle(_) => 0x13,
94 DataType::Vec { .. } => 0x14,
95 DataType::TensorShaped { .. } => 0x15,
96 DataType::SparseCsr { .. } => 0x16,
97 DataType::SparseCoo { .. } => 0x17,
98 DataType::SparseBsr { .. } => 0x18,
99 DataType::F8E4M3 => 0x19,
100 DataType::F8E5M2 => 0x1A,
101 DataType::I4 => 0x1B,
102 DataType::FP4 => 0x1C,
103 DataType::NF4 => 0x1D,
104 DataType::DeviceMesh { .. } => 0x1E,
105 DataType::Opaque(_) => 0x80,
106 _ => 0xFFFF_FFFF,
111 }
112}
113
114#[derive(Debug, Default, Clone)]
116pub struct SpecMap {
117 entries: BTreeMap<String, SpecValue>,
118}
119
120impl SpecMap {
121 #[must_use]
123 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn insert(&mut self, name: impl Into<String>, value: SpecValue) {
129 self.entries.insert(name.into(), value);
130 }
131
132 #[must_use]
134 pub fn len(&self) -> usize {
135 self.entries.len()
136 }
137
138 #[must_use]
140 pub fn is_empty(&self) -> bool {
141 self.entries.is_empty()
142 }
143
144 pub fn iter(&self) -> impl Iterator<Item = (&str, &SpecValue)> {
146 self.entries
147 .iter()
148 .map(|(key, value)| (key.as_str(), value))
149 }
150
151 #[must_use]
153 pub fn to_numeric_constants(&self) -> std::collections::HashMap<String, f64> {
154 let mut out = std::collections::HashMap::with_capacity(self.entries.len());
155 for (key, value) in &self.entries {
156 out.insert(key.clone(), value.as_pipeline_f64());
157 }
158 out
159 }
160
161 #[must_use]
163 pub fn cache_hash(&self) -> u64 {
164 let mut hash: u64 = 0xcbf29ce484222325;
165 for (name, value) in self.iter() {
166 for byte in name.as_bytes() {
167 hash ^= u64::from(*byte);
168 hash = hash.wrapping_mul(0x100000001b3);
169 }
170 for byte in value.cache_hash().to_le_bytes() {
171 hash ^= u64::from(byte);
172 hash = hash.wrapping_mul(0x100000001b3);
173 }
174 }
175 hash
176 }
177}
178
179#[derive(Debug, Clone, PartialEq, Eq, Hash)]
181pub struct SpecCacheKey {
182 pub shader_hash: u64,
184 pub binding_sig: u64,
186 pub workgroup_size: [u32; 3],
188 pub spec_hash: u64,
190}
191
192impl SpecCacheKey {
193 #[must_use]
195 pub fn new(
196 shader_hash: u64,
197 binding_sig: u64,
198 workgroup_size: [u32; 3],
199 specs: &SpecMap,
200 ) -> Self {
201 Self {
202 shader_hash,
203 binding_sig,
204 workgroup_size,
205 spec_hash: specs.cache_hash(),
206 }
207 }
208}
209
210#[must_use]
216pub fn vsa_specialization_key(program: &Program, spec_hash: u64) -> u128 {
217 let fingerprint = crate::launch::program_vsa_fingerprint_words(program);
218 let fp_lo = fingerprint
219 .iter()
220 .take(2)
221 .enumerate()
222 .fold(0_u64, |acc, (i, &word)| {
223 acc | (u64::from(word) << (32 * (i as u32)))
224 });
225 ((fp_lo as u128) << 64) | u128::from(spec_hash)
226}
227
228#[must_use]
235pub fn versioned_specialization_artifact_key(
236 cache_version: u32,
237 spec_hash: &str,
238 backend_fingerprint: &str,
239) -> String {
240 let mut hasher = blake3::Hasher::new();
241 hasher.update(b"vyre-specialization-artifact-key-v1\0version\0");
242 hasher.update(&cache_version.to_le_bytes());
243 hasher.update(b"\0spec\0");
244 hasher.update(&(spec_hash.len() as u64).to_le_bytes());
245 hasher.update(spec_hash.as_bytes());
246 hasher.update(b"\0backend\0");
247 hasher.update(&(backend_fingerprint.len() as u64).to_le_bytes());
248 hasher.update(backend_fingerprint.as_bytes());
249 let hash = hasher.finalize();
250 let mut key = String::with_capacity(64);
251 push_lower_hex(hash.as_bytes(), &mut key);
252 key
253}
254
/// Appends the lowercase hexadecimal form of `bytes` to `out`, two
/// characters per byte with the high nibble first. Existing content of
/// `out` is kept.
fn push_lower_hex(bytes: &[u8], out: &mut String) {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Best-effort reservation: a failure is ignored on purpose and the
    // string then grows on demand while characters are appended.
    let _ = out.try_reserve(bytes.len().saturating_mul(2));

    out.extend(bytes.iter().flat_map(|&byte| {
        let high = DIGITS[usize::from(byte >> 4)];
        let low = DIGITS[usize::from(byte & 0x0f)];
        [high, low].map(char::from)
    }));
}
264
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Builds a `SpecMap` by inserting `entries` in the order given.
    fn spec_map_of<'a>(entries: impl IntoIterator<Item = (&'a str, SpecValue)>) -> SpecMap {
        let mut map = SpecMap::new();
        for (name, value) in entries {
            map.insert(name, value);
        }
        map
    }

    /// Insertion order must not influence the map hash.
    #[test]
    fn spec_map_ordering_is_commutative() {
        let forward = spec_map_of([("A", SpecValue::U32(1)), ("B", SpecValue::U32(2))]);
        let reversed = spec_map_of([("B", SpecValue::U32(2)), ("A", SpecValue::U32(1))]);
        assert_eq!(forward.cache_hash(), reversed.cache_hash());
    }

    /// Keys that differ only in the specialization value must differ.
    #[test]
    fn cache_key_differs_by_spec_hash() {
        let one = spec_map_of([("K", SpecValue::U32(1))]);
        let two = spec_map_of([("K", SpecValue::U32(2))]);
        let key_one = SpecCacheKey::new(0xdead, 0xbeef, [64, 1, 1], &one);
        let key_two = SpecCacheKey::new(0xdead, 0xbeef, [64, 1, 1], &two);
        assert_ne!(key_one, key_two);
    }

    /// The upper half tracks the program, the lower half the spec hash.
    #[test]
    fn vsa_specialization_key_changes_only_low_half_for_spec_hash() {
        let buffers = vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)];
        let body = vec![Node::store("out", Expr::u32(0), Expr::u32(7))];
        let program = Program::wrapped(buffers, [1, 1, 1], body);

        let first = vsa_specialization_key(&program, 0x11);
        let second = vsa_specialization_key(&program, 0x22);

        assert_eq!(
            first >> 64,
            second >> 64,
            "Fix: VSA specialization keys must keep program identity independent from specialization values."
        );
        assert_ne!(
            first as u64, second as u64,
            "Fix: VSA specialization keys must include the specialization hash."
        );
    }

    /// Moving a character across the field boundary must change the key.
    #[test]
    fn versioned_artifact_key_separates_variable_length_fields() {
        let split_even = versioned_specialization_artifact_key(1, "ab", "cd");
        let split_late = versioned_specialization_artifact_key(1, "abc", "d");
        assert_ne!(
            split_even, split_late,
            "Fix: specialization artifact keys must length-prefix variable fields."
        );
    }

    /// A `DType` value keeps the data type it was built with.
    #[test]
    fn dtype_spec_value_round_trips() {
        let value = SpecValue::DType(DataType::F32);
        assert!(matches!(value, SpecValue::DType(DataType::F32)));
    }

    /// Every pair among F32 / U32 / I32 must hash differently.
    #[test]
    fn dtype_spec_distinct_dtypes_hash_distinct() {
        let hashes = [DataType::F32, DataType::U32, DataType::I32]
            .map(|dtype| SpecValue::DType(dtype).cache_hash());
        for (position, left) in hashes.iter().enumerate() {
            for right in &hashes[position + 1..] {
                assert_ne!(left, right);
            }
        }
    }

    /// Hashing is deterministic for equal inputs.
    #[test]
    fn dtype_spec_equal_dtypes_hash_equal() {
        let first = SpecValue::DType(DataType::F32).cache_hash();
        let second = SpecValue::DType(DataType::F32).cache_hash();
        assert_eq!(first, second);
    }

    /// The variant tag keeps `DType` apart from the scalar variants.
    #[test]
    fn dtype_spec_does_not_collide_with_other_variants() {
        let dtype_hash = SpecValue::DType(DataType::U32).cache_hash();
        let scalars = [
            SpecValue::U32(0),
            SpecValue::I32(0),
            SpecValue::F32(0.0),
            SpecValue::Bool(false),
        ];
        for scalar in scalars {
            assert_ne!(dtype_hash, scalar.cache_hash());
        }
    }

    /// Different dtypes under the same name must separate map hashes and
    /// the cache keys built from them.
    #[test]
    fn dtype_spec_separates_cache_key_in_specmap() {
        let with_f32 = spec_map_of([("dtype", SpecValue::DType(DataType::F32))]);
        let with_u32 = spec_map_of([("dtype", SpecValue::DType(DataType::U32))]);
        assert_ne!(
            with_f32.cache_hash(),
            with_u32.cache_hash(),
            "Fix: dtype-keyed SpecMaps must produce distinct cache hashes."
        );
        assert_ne!(
            SpecCacheKey::new(0, 0, [1, 1, 1], &with_f32),
            SpecCacheKey::new(0, 0, [1, 1, 1], &with_u32)
        );
    }

    /// Each data type listed here must have its own non-fallback tag.
    #[test]
    fn dtype_tag_covers_every_data_type() {
        let element = || Box::new(DataType::U32);
        let known = [
            DataType::U32,
            DataType::I32,
            DataType::U64,
            DataType::Vec2U32,
            DataType::Vec4U32,
            DataType::Bool,
            DataType::Bytes,
            DataType::Array { element_size: 1 },
            DataType::F16,
            DataType::BF16,
            DataType::F32,
            DataType::F64,
            DataType::Tensor,
            DataType::U8,
            DataType::U16,
            DataType::I8,
            DataType::I16,
            DataType::I64,
            DataType::Handle(vyre_spec::data_type::TypeId(0)),
            DataType::Vec {
                element: element(),
                count: 1,
            },
            DataType::TensorShaped {
                element: element(),
                shape: smallvec::smallvec![1],
            },
            DataType::SparseCsr { element: element() },
            DataType::SparseCoo { element: element() },
            DataType::SparseBsr {
                element: element(),
                block_rows: 1,
                block_cols: 1,
            },
            DataType::F8E4M3,
            DataType::F8E5M2,
            DataType::I4,
            DataType::FP4,
            DataType::NF4,
            DataType::DeviceMesh {
                axes: smallvec::smallvec![1],
            },
        ];

        let mut seen = std::collections::BTreeSet::new();
        for dtype in &known {
            let tag = dtype_tag(dtype);
            assert_ne!(
                tag, 0xFFFF_FFFF,
                "Fix: dtype_tag missing arm for {dtype:?} - extend specialization.rs::dtype_tag."
            );
            let newly_seen = seen.insert(tag);
            assert!(
                newly_seen,
                "Fix: dtype_tag returned duplicate tag {tag} for {dtype:?}."
            );
        }
    }
}