1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 AdapterInfo,
5};
6
7pub const HEADER_LENGTH: usize = size_of::<PipelineCacheHeader>();
8
9#[derive(Debug, PartialEq, Eq, Clone, Error)]
10#[non_exhaustive]
11pub enum PipelineCacheValidationError {
12 #[error("The pipeline cache data was truncated")]
13 Truncated,
14 #[error("The pipeline cache data was longer than recorded")]
15 Extended,
17 #[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")]
18 Corrupted,
19 #[error("The pipeline cacha data was out of date and so cannot be safely used")]
20 Outdated,
21 #[error("The cache data was created for a different device")]
22 DeviceMismatch,
23 #[error("Pipeline cacha data was created for a future version of wgpu")]
24 Unsupported,
25}
26
27impl PipelineCacheValidationError {
28 pub fn was_avoidable(&self) -> bool {
31 match self {
32 PipelineCacheValidationError::DeviceMismatch => true,
33 PipelineCacheValidationError::Truncated
34 | PipelineCacheValidationError::Unsupported
35 | PipelineCacheValidationError::Extended
36 | PipelineCacheValidationError::Outdated
38 | PipelineCacheValidationError::Corrupted => false,
39 }
40 }
41}
42
43impl WebGpuError for PipelineCacheValidationError {
44 fn webgpu_error_type(&self) -> ErrorType {
45 ErrorType::Validation
46 }
47}
48
49pub fn validate_pipeline_cache<'d>(
51 cache_data: &'d [u8],
52 adapter: &AdapterInfo,
53 validation_key: [u8; 16],
54) -> Result<&'d [u8], PipelineCacheValidationError> {
55 let adapter_key = adapter_key(adapter)?;
56 let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else {
57 return Err(PipelineCacheValidationError::Truncated);
58 };
59 if header.magic != MAGIC {
60 return Err(PipelineCacheValidationError::Corrupted);
61 }
62 if header.header_version != HEADER_VERSION {
63 return Err(PipelineCacheValidationError::Outdated);
64 }
65 if header.cache_abi != ABI {
66 return Err(PipelineCacheValidationError::Outdated);
67 }
68 if header.backend != adapter.backend as u8 {
69 return Err(PipelineCacheValidationError::DeviceMismatch);
70 }
71 if header.adapter_key != adapter_key {
72 return Err(PipelineCacheValidationError::DeviceMismatch);
73 }
74 if header.validation_key != validation_key {
75 return Err(PipelineCacheValidationError::Outdated);
79 }
80 let data_size: usize = header
81 .data_size
82 .try_into()
83 .map_err(|_| PipelineCacheValidationError::Corrupted)?;
86 if remaining_data.len() < data_size {
87 return Err(PipelineCacheValidationError::Truncated);
88 }
89 if remaining_data.len() > data_size {
90 return Err(PipelineCacheValidationError::Extended);
91 }
92 if header.hash_space != HASH_SPACE_VALUE {
93 return Err(PipelineCacheValidationError::Corrupted);
94 }
95 Ok(remaining_data)
96}
97
98pub fn add_cache_header(
99 in_region: &mut [u8],
100 data: &[u8],
101 adapter: &AdapterInfo,
102 validation_key: [u8; 16],
103) {
104 assert_eq!(in_region.len(), HEADER_LENGTH);
105 let header = PipelineCacheHeader {
106 adapter_key: adapter_key(adapter)
107 .expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"),
108 backend: adapter.backend as u8,
109 cache_abi: ABI,
110 magic: MAGIC,
111 header_version: HEADER_VERSION,
112 validation_key,
113 hash_space: HASH_SPACE_VALUE,
114 data_size: data
115 .len()
116 .try_into()
117 .expect("Cache larger than u64::MAX bytes"),
118 };
119 header.write(in_region);
120}
121
/// Fixed tag at the start of every header; any other value is reported as `Corrupted`.
const MAGIC: [u8; 8] = *b"WGPUPLCH";

/// Version of the header layout; a mismatch is reported as `Outdated`.
const HEADER_VERSION: u32 = 1;

/// Pointer width, in bytes, of the process that wrote the cache (e.g. 8 on 64-bit targets).
/// A mismatch is reported as `Outdated`.
const ABI: u32 = core::mem::size_of::<*const ()>() as u32;

/// Value the header's `hash_space` field must hold; anything else is reported as `Corrupted`.
///
/// NOTE(review): presumably this reserves room for a real content hash in a later header
/// version — confirm.
const HASH_SPACE_VALUE: u64 = 0xFEDC_BA98_7654_3210;
136
/// In-memory form of the fixed-size header that precedes pipeline cache data.
///
/// Fields are declared in serialization order. `PipelineCacheHeader::read` and
/// `PipelineCacheHeader::write` encode them big-endian, one after another. `#[repr(C)]` pins
/// this declaration order, and the fields happen to need no padding
/// (8 + 4 + 4 + 1 + 15 + 16 + 8 + 8 = 64 bytes). As a result, `size_of::<PipelineCacheHeader>()`
/// (used for `HEADER_LENGTH`) equals the serialized length. `read` asserts it consumed exactly
/// that many bytes. Do not reorder or resize fields without updating both.
#[repr(C)]
#[derive(PartialEq, Eq)]
struct PipelineCacheHeader {
    // Must equal `MAGIC`; otherwise validation reports `Corrupted`.
    magic: [u8; 8],
    // Must equal `HEADER_VERSION`; otherwise `Outdated`.
    header_version: u32,
    // Pointer width of the writer (`ABI`); a mismatch is `Outdated`.
    cache_abi: u32,
    // `adapter.backend as u8` of the writing adapter; a mismatch is `DeviceMismatch`.
    backend: u8,
    // Output of `adapter_key` for the writing adapter; a mismatch is `DeviceMismatch`.
    adapter_key: [u8; 15],
    // Caller-supplied key; a mismatch is `Outdated`.
    validation_key: [u8; 16],
    // Length in bytes of the payload that follows the header.
    data_size: u64,
    // Must equal `HASH_SPACE_VALUE`; otherwise `Corrupted`.
    // NOTE(review): presumably reserved for a future content hash — confirm.
    hash_space: u64,
}
175
176impl PipelineCacheHeader {
177 fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> {
178 let mut reader = Reader {
179 data,
180 total_read: 0,
181 };
182 let magic = reader.read_array()?;
183 let header_version = reader.read_u32()?;
184 let cache_abi = reader.read_u32()?;
185 let backend = reader.read_byte()?;
186 let adapter_key = reader.read_array()?;
187 let validation_key = reader.read_array()?;
188 let data_size = reader.read_u64()?;
189 let data_hash = reader.read_u64()?;
190
191 assert_eq!(reader.total_read, size_of::<PipelineCacheHeader>());
192
193 Some((
194 PipelineCacheHeader {
195 magic,
196 header_version,
197 cache_abi,
198 backend,
199 adapter_key,
200 validation_key,
201 data_size,
202 hash_space: data_hash,
203 },
204 reader.data,
205 ))
206 }
207
208 fn write(&self, into: &mut [u8]) -> Option<()> {
209 let mut writer = Writer { data: into };
210 writer.write_array(&self.magic)?;
211 writer.write_u32(self.header_version)?;
212 writer.write_u32(self.cache_abi)?;
213 writer.write_byte(self.backend)?;
214 writer.write_array(&self.adapter_key)?;
215 writer.write_array(&self.validation_key)?;
216 writer.write_u64(self.data_size)?;
217 writer.write_u64(self.hash_space)?;
218
219 assert_eq!(writer.data.len(), 0);
220 Some(())
221 }
222}
223
224fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> {
225 match adapter.backend {
226 wgt::Backend::Vulkan => {
227 let v: [u8; 4] = adapter.vendor.to_be_bytes();
230 let d: [u8; 4] = adapter.device.to_be_bytes();
231 let adapter = [
232 255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255,
233 ];
234 Ok(adapter)
235 }
236 _ => Err(PipelineCacheValidationError::Unsupported),
237 }
238}
239
/// Forward-only cursor over a byte slice that decodes fixed-size arrays and big-endian integers.
struct Reader<'a> {
    /// Bytes not yet consumed.
    data: &'a [u8],
    /// Number of bytes consumed so far.
    total_read: usize,
}

impl Reader<'_> {
    /// Consumes one byte, or returns `None` if the input is exhausted.
    fn read_byte(&mut self) -> Option<u8> {
        // Delegates to `read_array`, mirroring `Writer::write_byte`.
        self.read_array::<1>().map(|[byte]| byte)
    }

    /// Consumes the next `N` bytes as an array.
    ///
    /// Returns `None` without consuming anything if fewer than `N` bytes remain.
    fn read_array<const N: usize>(&mut self) -> Option<[u8; N]> {
        // `split_first_chunk` performs the length check and yields `&[u8; N]` directly, so
        // there is no fallible slice-to-array conversion (and no unreachable panic path).
        let (head, rest) = self.data.split_first_chunk::<N>()?;
        self.data = rest;
        self.total_read += N;
        Some(*head)
    }

    /// Consumes 4 bytes as a big-endian `u32`.
    fn read_u32(&mut self) -> Option<u32> {
        self.read_array().map(u32::from_be_bytes)
    }

    /// Consumes 8 bytes as a big-endian `u64`.
    fn read_u64(&mut self) -> Option<u64> {
        self.read_array().map(u64::from_be_bytes)
    }
}
273
/// Forward-only cursor that fills a mutable byte slice with arrays and big-endian integers.
struct Writer<'a> {
    /// The not-yet-written tail of the output buffer.
    data: &'a mut [u8],
}

impl Writer<'_> {
    /// Writes a single byte.
    fn write_byte(&mut self, byte: u8) -> Option<()> {
        self.write_array(&[byte])
    }

    /// Copies `array` into the front of the remaining buffer and advances past it.
    ///
    /// Returns `None`, leaving the cursor untouched, if fewer than `N` bytes remain.
    fn write_array<const N: usize>(&mut self, array: &[u8; N]) -> Option<()> {
        if self.data.len() < N {
            return None;
        }
        // Temporarily take the slice out so the split halves keep the full buffer lifetime.
        let (head, tail) = core::mem::take(&mut self.data).split_at_mut(N);
        head.copy_from_slice(array);
        self.data = tail;
        Some(())
    }

    /// Writes `value` as 4 big-endian bytes.
    fn write_u32(&mut self, value: u32) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }

    /// Writes `value` as 8 big-endian bytes.
    fn write_u64(&mut self, value: u64) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }
}
304
#[cfg(test)]
mod tests {
    use alloc::{string::String, vec::Vec};
    use wgt::AdapterInfo;

    use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH};

    use super::ABI;

    // Compile-time check that the serialized header is exactly 64 bytes.
    const _: [(); HEADER_LENGTH] = [(); 64];

    const ADAPTER: AdapterInfo = AdapterInfo {
        name: String::new(),
        vendor: 0x0002_FEED,
        device: 0xFEFE_FEFE,
        device_type: wgt::DeviceType::Other,
        driver: String::new(),
        driver_info: String::new(),
        backend: wgt::Backend::Vulkan,
    };

    const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888);

    /// The header as eight 8-byte rows.
    type HeaderRows = [[u8; 8]; HEADER_LENGTH / 8];

    /// The header expected for `ADAPTER`, `VALIDATION_KEY` and an empty payload.
    ///
    /// Each negative test copies this and alters exactly one field, so the expected error can
    /// only come from that field.
    fn valid_rows() -> HeaderRows {
        [
            *b"WGPUPLCH",                                 // magic
            [0, 0, 0, 1, 0, 0, 0, ABI as u8],             // header_version, cache_abi
            [1, 255, 255, 255, 0, 2, 0xFE, 0xED],         // backend, adapter_key[0..7]
            [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // adapter_key[7..15]
            0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(),         // validation_key[0..8]
            0x88888888_88888888u64.to_be_bytes(),         // validation_key[8..16]
            0x0u64.to_be_bytes(),                         // data_size
            0xFEDCBA9_876543210u64.to_be_bytes(),         // hash_space
        ]
    }

    /// Flattens `rows` into bytes and appends `payload_len` zero bytes of payload.
    fn serialize(rows: HeaderRows, payload_len: usize) -> Vec<u8> {
        rows.into_iter()
            .flatten()
            .chain(core::iter::repeat_n(0u8, payload_len))
            .collect()
    }

    /// Validates `cache` against the test adapter and validation key.
    fn validate(cache: &[u8]) -> Result<&[u8], E> {
        super::validate_pipeline_cache(cache, &ADAPTER, VALIDATION_KEY)
    }

    #[test]
    fn written_header() {
        let mut result = [0; HEADER_LENGTH];
        super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY);
        let expected = serialize(valid_rows(), 0);
        assert_eq!(result.as_slice(), expected.as_slice());
    }

    #[test]
    fn valid_data() {
        let cache = serialize(valid_rows(), 0);
        let expected: &[u8] = &[];
        assert_eq!(validate(&cache), Ok(expected));
    }

    #[test]
    fn invalid_magic() {
        let mut rows = valid_rows();
        rows[0] = *b"NOT_WGPU";
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::Corrupted));
    }

    #[test]
    fn wrong_version() {
        let mut rows = valid_rows();
        rows[1][3] = 2; // header_version = 2
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::Outdated));
    }

    #[test]
    fn wrong_abi() {
        let mut rows = valid_rows();
        rows[1][7] = 14; // cache_abi = 14, not a real pointer width
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::Outdated));
    }

    #[test]
    fn wrong_backend() {
        let mut rows = valid_rows();
        rows[2][0] = 2; // backend byte no longer Vulkan
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_adapter() {
        let mut rows = valid_rows();
        rows[2][7] = 0x00; // last vendor byte
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_validation() {
        let mut rows = valid_rows();
        rows[5] = 0x88888888_00000000u64.to_be_bytes();
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::Outdated));
    }

    #[test]
    fn too_little_data() {
        let mut rows = valid_rows();
        rows[6] = 100u64.to_be_bytes();
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::Truncated));
    }

    #[test]
    fn not_no_data() {
        let mut rows = valid_rows();
        rows[6] = 100u64.to_be_bytes();
        let cache = serialize(rows, 100);
        let expected: &[u8] = &[0; 100];
        assert_eq!(validate(&cache), Ok(expected));
    }

    #[test]
    fn too_much_data() {
        let mut rows = valid_rows();
        rows[6] = 100u64.to_be_bytes();
        let cache = serialize(rows, 200);
        assert_eq!(validate(&cache), Err(E::Extended));
    }

    #[test]
    fn wrong_hash() {
        let mut rows = valid_rows();
        rows[7] = 0u64.to_be_bytes();
        let cache = serialize(rows, 0);
        assert_eq!(validate(&cache), Err(E::Corrupted));
    }

    #[test]
    fn truncated_header() {
        let cache = serialize(valid_rows(), 0);
        assert_eq!(validate(&cache[..HEADER_LENGTH - 1]), Err(E::Truncated));
        assert_eq!(validate(&[]), Err(E::Truncated));
    }

    #[test]
    fn unsupported_backend() {
        let adapter = AdapterInfo {
            backend: wgt::Backend::Metal,
            ..ADAPTER
        };
        let cache = serialize(valid_rows(), 0);
        let validation_result = super::validate_pipeline_cache(&cache, &adapter, VALIDATION_KEY);
        assert_eq!(validation_result, Err(E::Unsupported));
    }
}