vyre_driver/pipeline/
hashing.rs1use crate::backend::DispatchConfig;
4use vyre_foundation::ir::Program;
5use vyre_foundation::serial::wire::{append_data_type_fingerprint, append_node_list_fingerprint};
6use vyre_spec::BackendId;
7
8pub fn try_normalized_program_cache_digest(program: &Program) -> Result<[u8; 32], String> {
16 thread_local! {
17 static SCRATCH: std::cell::RefCell<Vec<u8>> = std::cell::RefCell::new(Vec::with_capacity(1024));
18 }
19 SCRATCH.with(|cell| {
20 let mut scratch = cell.borrow_mut();
21 scratch.clear();
22 scratch.extend_from_slice(b"vyre-pipeline-cache-norm-v2\0wg\0");
23 for axis in program.workgroup_size() {
24 scratch.extend_from_slice(&axis.to_le_bytes());
25 }
26 scratch.extend_from_slice(b"\0op\0");
27 match program.entry_op_id() {
28 Some(op) => scratch.extend_from_slice(op.as_bytes()),
29 None => scratch.extend_from_slice(b"<anon>"),
30 }
31 scratch.extend_from_slice(b"\0v\0");
32 scratch.push(u8::from(program.is_structurally_validated()));
33 scratch.extend_from_slice(b"\0bufs\0");
34 for buffer in program.buffers().iter() {
35 scratch.extend_from_slice(buffer.name().as_bytes());
36 scratch.push(0);
37 scratch.push(buffer.kind() as u8);
38 scratch.push(buffer.access() as u8);
39 append_data_type_fingerprint(&mut scratch, &buffer.element()).map_err(|message| {
40 format!(
41 "failed to fingerprint pipeline-cache buffer data type `{}`: {message}. Fix: validate and normalize the Program before computing a compiled-pipeline cache key; invalid IR must not enter cache identity.",
42 buffer.name()
43 )
44 })?;
45 scratch.push(0);
46 }
47 scratch.extend_from_slice(b"\0body\0");
48 append_node_list_fingerprint(&mut scratch, program.entry()).map_err(|message| {
49 format!(
50 "failed to fingerprint pipeline-cache Program body: {message}. Fix: validate and normalize the Program before computing a compiled-pipeline cache key; invalid IR must not enter cache identity."
51 )
52 })?;
53 Ok(*blake3::hash(&scratch).as_bytes())
54 })
55}
56
57#[must_use]
59pub fn normalized_program_cache_digest(program: &Program) -> [u8; 32] {
60 try_normalized_program_cache_digest(program).unwrap_or([0u8; 32])
61}
62
63pub fn update_dispatch_policy_cache_hash(hasher: &mut blake3::Hasher, config: &DispatchConfig) {
66 hasher.update(b"ulp\0");
67 match config.ulp_budget {
68 Some(ulp) => {
69 hasher.update(&[1, ulp]);
70 }
71 None => {
72 hasher.update(&[0, 0]);
73 }
74 };
75 hasher.update(b"\0wg\0");
76 match config.workgroup_override {
77 Some(workgroup) => {
78 hasher.update(&[1]);
79 for axis in workgroup {
80 hasher.update(&axis.to_le_bytes());
81 }
82 }
83 None => {
84 hasher.update(&[0]);
85 }
86 };
87}
88
89#[must_use]
95pub fn dispatch_policy_cache_digest(config: &DispatchConfig) -> [u8; 32] {
96 let mut hasher = blake3::Hasher::new();
97 update_dispatch_policy_cache_hash(&mut hasher, config);
98 *hasher.finalize().as_bytes()
99}
100
101#[must_use]
103pub fn dispatch_policy_cache_string(config: &DispatchConfig) -> String {
104 let mut policy = String::with_capacity(64);
108 policy.push_str("ulp=");
109 push_debug_option_u8(&mut policy, config.ulp_budget);
110 policy.push_str(":wg=");
111 push_debug_option_workgroup(&mut policy, config.workgroup_override);
112 policy
113}
114
/// Encodes `bytes` as a lowercase hexadecimal string, two characters per byte
/// (high nibble first). An empty slice yields an empty string.
#[must_use]
pub fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];

    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(bytes.iter().flat_map(|&byte| {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0f);
        [DIGITS[high], DIGITS[low]]
    }));
    encoded
}
126
127#[must_use]
129pub fn hex_short(bytes: &[u8; 32]) -> String {
130 hex_encode(&bytes[..8])
131}
132
133#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
135pub struct PipelineDeviceFingerprint {
136 pub vendor: u32,
138 pub device: u32,
140 pub driver_digest: [u8; 32],
142}
143
144impl PipelineDeviceFingerprint {
145 #[must_use]
147 pub fn from_parts(vendor: u32, device: u32, revision: &str, revision_extra: &str) -> Self {
148 let mut hasher = blake3::Hasher::new();
149 hasher.update(b"vyre-pipeline-device-fingerprint-v1\0");
150 hasher.update(revision.as_bytes());
151 hasher.update(b"\0extra\0");
152 hasher.update(revision_extra.as_bytes());
153 Self {
154 vendor,
155 device,
156 driver_digest: *hasher.finalize().as_bytes(),
157 }
158 }
159
160 #[must_use]
162 pub fn cache_key(self, program_digest: [u8; 32]) -> [u8; 32] {
163 let mut hasher = blake3::Hasher::new();
164 hasher.update(b"vyre-disk-pipeline-cache-key-v1\0program\0");
165 hasher.update(&program_digest);
166 hasher.update(b"\0vendor\0");
167 hasher.update(&self.vendor.to_le_bytes());
168 hasher.update(b"\0device\0");
169 hasher.update(&self.device.to_le_bytes());
170 hasher.update(b"\0driver\0");
171 hasher.update(&self.driver_digest);
172 *hasher.finalize().as_bytes()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::{dispatch_policy_cache_digest, update_dispatch_policy_cache_hash};
    use crate::backend::DispatchConfig;

    /// Checks, over 4096 generated configurations, that the one-shot digest equals
    /// the result of streaming the same config through the shared update function.
    #[test]
    fn dispatch_policy_cache_digest_matches_shared_hasher_for_generated_configs() {
        for case in 0..4096u32 {
            // Bit 0 selects a ULP budget, bit 1 selects a workgroup override;
            // unselected fields keep their default value.
            let with_ulp_budget = case & 1 != 0;
            let with_workgroup_override = case & 2 != 0;

            let mut config = DispatchConfig::default();
            if with_ulp_budget {
                let budget = (case as u8).wrapping_mul(17).wrapping_add(1);
                config.ulp_budget = Some(budget);
            }
            if with_workgroup_override {
                let x = 1 + (case & 255);
                let y = 1 + ((case.rotate_left(7) >> 3) & 31);
                let z = 1 + ((case.rotate_right(5) >> 2) & 7);
                config.workgroup_override = Some([x, y, z]);
            }

            let mut streamed = blake3::Hasher::new();
            update_dispatch_policy_cache_hash(&mut streamed, &config);

            assert_eq!(
                dispatch_policy_cache_digest(&config),
                *streamed.finalize().as_bytes(),
                "Fix: dispatch-policy digest must stay single-sourced through update_dispatch_policy_cache_hash for generated case {case}."
            );
        }
    }
}
206
207pub(super) fn push_debug_option_u8(out: &mut String, value: Option<u8>) {
208 match value {
209 Some(value) => {
210 out.push_str("Some(");
211 push_decimal_u8(out, value);
212 out.push(')');
213 }
214 None => out.push_str("None"),
215 }
216}
217
218pub(super) fn push_debug_option_workgroup(out: &mut String, value: Option<[u32; 3]>) {
219 match value {
220 Some([x, y, z]) => {
221 out.push_str("Some([");
222 push_decimal_u32(out, x);
223 out.push_str(", ");
224 push_decimal_u32(out, y);
225 out.push_str(", ");
226 push_decimal_u32(out, z);
227 out.push_str("])");
228 }
229 None => out.push_str("None"),
230 }
231}
232
233pub(super) fn push_decimal_u8(out: &mut String, value: u8) {
234 push_decimal_u32(out, u32::from(value));
235}
236
237pub(super) fn push_decimal_u32(out: &mut String, value: u32) {
238 let mut buf = [0_u8; 10];
239 let mut n = value;
240 let mut i = buf.len();
241 if n == 0 {
242 out.push('0');
243 return;
244 }
245 while n > 0 {
246 i -= 1;
247 buf[i] = b'0' + (n % 10) as u8;
248 n /= 10;
249 }
250 out.push_str(std::str::from_utf8(&buf[i..]).unwrap_or("0"));
251}