1use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
4use crate::engine::decode::{DecodeFormat, DecodeRules, DecodedRegion};
5use crate::runtime::cache::{BufferPool, PooledBuffer};
6use crate::runtime::{bg_entry, compile_compute_pipeline};
7use vyre::{Error, Result};
8use bytemuck::{Pod, Zeroable};
9use std::sync::mpsc;
10
11pub(crate) fn dispatch_decode(
14 format: DecodeFormat,
15 input: &[u8],
16 rules: &DecodeRules,
17 command_encoder: Option<&mut wgpu::CommandEncoder>,
18) -> Result<Vec<DecodedRegion>> {
19 if input.is_empty() {
20 return Ok(Vec::new());
21 }
22 if input.len() > MAX_DECODE_INPUT_BYTES {
23 return Err(Error::Decode {
24 message: format!(
25 "decode input is {} bytes, exceeding {MAX_DECODE_INPUT_BYTES}. Fix: split the input before GPU decode dispatch.",
26 input.len()
27 ),
28 });
29 }
30 let (device, queue) = crate::runtime::cached_device()?;
31 let input_len = u32::try_from(input.len()).map_err(|source| Error::Decode {
32 message: format!("input size {} exceeds u32::MAX: {source}. Fix: split the decode input before GPU dispatch.", input.len()),
33 })?;
34 validate_gpu_sizes(device, input_len, input.len())?;
35 let max_regions = input_len.min(MAX_DECODE_REGIONS);
36 let region_bytes = usize::try_from(max_regions)
37 .map_err(|source| Error::Decode {
38 message: format!("decode max_regions {max_regions} cannot fit usize: {source}. Fix: run on a supported target."),
39 })?
40 .checked_mul(size_of::<RegionMeta>())
41 .ok_or_else(|| Error::Decode {
42 message: "decode regions buffer size overflow. Fix: split the decode input before GPU dispatch.".to_string(),
43 })?;
44
45 let params = Params {
46 input_len,
47 min_run: format.min_run(rules),
48 max_regions,
49 output_size: input_len,
50 };
51 let pool = BufferPool::global();
52 let input_bytes = align_storage_bytes(input.len())?;
53 let input_buf = pool.acquire(
54 device,
55 "vyre decode input",
56 input_bytes,
57 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
58 )?;
59 queue.write_buffer(&input_buf, 0, input);
60 write_zero_padding(queue, &input_buf, input.len(), input_bytes)?;
61 let regions_buf = zeroed_storage(device, "vyre decode regions", region_bytes.max(16))?;
62 let output_bytes = input.len().checked_mul(4).ok_or_else(|| Error::Decode {
63 message:
64 "decode output buffer size overflow. Fix: split the decode input before GPU dispatch."
65 .to_string(),
66 })?;
67 let output_buf = zeroed_storage(device, "vyre decode output", output_bytes.max(16))?;
68 let counters_buf = zeroed_storage(device, "vyre decode counters", 16)?;
69 let params_array = [params];
70 let params_bytes = safe_bytes_of_slice(¶ms_array);
71 let params_buf = pool.acquire(
72 device,
73 "vyre decode params",
74 u64::try_from(params_bytes.len()).map_err(|source| Error::Decode {
75 message: format!("decode params buffer size cannot fit u64: {source}. Fix: run on a supported target."),
76 })?,
77 wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
78 )?;
79 queue.write_buffer(¶ms_buf, 0, params_bytes);
80
81 vyre::ops::registry::gate::verify_certificate(format.op_id()).map_err(|source| Error::Decode {
82 message: source.to_string(),
83 })?;
84
85 let wgsl = format.wgsl();
86 let pipeline = compile_compute_pipeline(device, format.label(), &wgsl, "main")?;
87 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
88 label: Some("vyre decode bind group"),
89 layout: &pipeline.get_bind_group_layout(0),
90 entries: &[
91 bg_entry(0, &input_buf),
92 bg_entry(1, ®ions_buf),
93 bg_entry(2, &output_buf),
94 bg_entry(3, &counters_buf),
95 bg_entry(4, ¶ms_buf),
96 ],
97 });
98
99 let mut owned_encoder = command_encoder.is_none().then(|| {
100 device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
101 label: Some("vyre decode encoder"),
102 })
103 });
104 let encoder = if let Some(encoder) = command_encoder {
105 encoder
106 } else {
107 owned_encoder
108 .as_mut()
109 .expect("owned encoder must be present when command_encoder is omitted")
110 };
111 encoder.clear_buffer(®ions_buf, 0, Some(regions_buf.size()));
112 encoder.clear_buffer(&output_buf, 0, Some(output_buf.size()));
113 encoder.clear_buffer(&counters_buf, 0, Some(counters_buf.size()));
114 {
115 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
116 label: Some("vyre decode pass"),
117 timestamp_writes: None,
118 });
119 pass.set_pipeline(&pipeline);
120 pass.set_bind_group(0, &bind_group, &[]);
121 pass.dispatch_workgroups(input_len.div_ceil(WORKGROUP_SIZE), 1, 1);
122 }
123 let counters_readback = readback_buffer(device, encoder, &counters_buf, 16)?;
124 let regions_readback = readback_buffer(
125 device,
126 encoder,
127 ®ions_buf,
128 u64::try_from(region_bytes.max(16)).map_err(|source| Error::Decode {
129 message: format!("decode region readback size cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."),
130 })?,
131 )?;
132 let output_readback = readback_buffer(
133 device,
134 encoder,
135 &output_buf,
136 u64::try_from(output_bytes.max(16)).map_err(|source| Error::Decode {
137 message: format!("decode output readback size cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."),
138 })?,
139 )?;
140 let Some(owned_encoder) = owned_encoder else {
141 return Err(Error::Decode {
142 message: "dispatch_decode was called with an external command encoder, but this API returns decoded readback data that is unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred decode API that returns readback buffers.".to_string(),
143 });
144 };
145 let submission = queue.submit(Some(owned_encoder.finish()));
146 let regions = readback_regions(
147 device,
148 &counters_readback,
149 ®ions_readback,
150 &output_readback,
151 input.len(),
152 submission,
153 )?;
154 pool.release(input_buf);
155 pool.release(regions_buf);
156 pool.release(output_buf);
157 pool.release(counters_buf);
158 pool.release(params_buf);
159 pool.release(counters_readback);
160 pool.release(regions_readback);
161 pool.release(output_readback);
162 Ok(regions)
163}
164
/// Map the counters, region-metadata, and output readback buffers for a
/// completed `submission` and rebuild validated [`DecodedRegion`]s.
///
/// The GPU output is treated as untrusted: the region count is clamped to
/// both `input_len` and `MAX_DECODE_REGIONS`, every source span is checked
/// against `input_len`, and every destination span is checked against the
/// actual output readback length. Any out-of-bounds or overflowing value is
/// rejected with `Error::Decode` rather than panicking.
///
/// Buffer layout (established by the dispatch code in this file):
/// - `counters_readback`: word 0 is the shader-reported region count.
/// - `regions_readback`: packed `RegionMeta` records (4 u32 words each).
/// - `output_readback`: one u32 word per decoded byte; only the low 8 bits
///   of each word are meaningful.
pub fn readback_regions(
    device: &wgpu::Device,
    counters_readback: &wgpu::Buffer,
    regions_readback: &wgpu::Buffer,
    output_readback: &wgpu::Buffer,
    input_len: usize,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<DecodedRegion>> {
    // Each map call waits on the same submission; cloning the index is cheap.
    let counters = map_readback_u32(device, counters_readback, submission.clone())?;
    let region_meta = map_readback_region_meta(device, regions_readback, submission.clone())?;
    let output_words = map_readback_u32(device, output_readback, submission)?;
    let max_regions = u32::try_from(input_len).map_err(|source| Error::Decode {
        message: format!("input length {input_len} cannot fit u32 while bounding readback: {source}. Fix: split the decode input before GPU dispatch."),
    })?.min(MAX_DECODE_REGIONS);
    // counters[0] is the shader-reported count; a missing word reads as 0.
    let region_count = usize::try_from(counters.first().copied().unwrap_or(0).min(max_regions))
        .map_err(|source| Error::Decode {
            message: format!("region count cannot fit usize: {source}. Fix: reject this GPU readback on this platform."),
        })?;
    if region_count
        > usize::try_from(MAX_DECODE_REGIONS).map_err(|source| Error::Decode {
            message: format!(
                "MAX_DECODE_REGIONS cannot fit usize: {source}. Fix: run on a supported target."
            ),
        })?
    {
        return Err(Error::Decode {
            message: format!(
                "GPU region count {region_count} exceeds {MAX_DECODE_REGIONS}. Fix: reject this malformed decoder output."
            ),
        });
    }
    let mut decoded = Vec::with_capacity(region_count);
    for meta in region_meta.into_iter().take(region_count) {
        // Zero-length records are unused slots; skip silently.
        if meta.src_len == 0 || meta.dst_len == 0 {
            continue;
        }
        let src_offset = usize::try_from(meta.src_offset).map_err(|source| Error::Decode {
            message: format!("source offset {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.src_offset),
        })?;
        let src_len = usize::try_from(meta.src_len).map_err(|source| Error::Decode {
            message: format!("source length {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.src_len),
        })?;
        let src_end = src_offset.checked_add(src_len).ok_or_else(|| Error::Decode {
            message: "GPU readback src_offset+src_len overflow. Fix: reject this malformed decoder output.".to_string(),
        })?;
        if src_end > input_len {
            return Err(Error::Decode {
                message: format!(
                    "GPU source region [{src_offset}, {src_end}) exceeds input length {input_len}. Fix: reject this malformed decoder output and inspect the decode shader."
                ),
            });
        }
        let dst_start = usize::try_from(meta.dst_offset).map_err(|source| Error::Decode {
            message: format!("output offset {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.dst_offset),
        })?;
        let dst_end = dst_start
            .checked_add(usize::try_from(meta.dst_len).map_err(|source| Error::Decode {
                message: format!("output length {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.dst_len),
            })?)
            .ok_or_else(|| Error::Decode {
                message: "output region overflow during readback. Fix: reject this malformed decoder output.".to_string(),
            })?;
        if dst_end > output_words.len() {
            return Err(Error::Decode {
                message: "shader emitted output beyond allocated readback storage. Fix: reject this malformed decoder output and inspect the decode shader.".to_string(),
            });
        }
        // Each output word holds one byte in its low 8 bits; the mask makes
        // the u8 conversion infallible in practice.
        let decoded_bytes = output_words[dst_start..dst_end]
            .iter()
            .map(|word| {
                u8::try_from(*word & 0xff).map_err(|source| Error::Decode {
                    message: format!("masked output byte could not fit u8: {source}. Fix: report this impossible conversion failure."),
                })
            })
            .collect::<Result<Vec<_>>>()?;
        decoded.push(DecodedRegion {
            offset: src_offset,
            length: src_len,
            decoded_bytes,
        });
    }
    Ok(decoded)
}
249
/// Threads per compute workgroup; dispatch sizing divides the input length by
/// this, so it presumably matches `@workgroup_size` in the decode WGSL — confirm
/// against `format.wgsl()`.
pub const WORKGROUP_SIZE: u32 = 64;

/// Byte size of one `RegionMeta` record (four u32 fields, matching
/// `size_of::<RegionMeta>()`); used when sizing the regions buffer against GPU limits.
pub const REGION_META_SIZE: u64 = 16;

/// Hard cap on decode input (64 MiB); larger inputs are rejected and must be
/// split by the caller before dispatch.
pub const MAX_DECODE_INPUT_BYTES: usize = 64 * 1024 * 1024;

/// Upper bound on regions the shader may emit; shader-reported counts are
/// clamped to this during readback.
pub const MAX_DECODE_REGIONS: u32 = 1_000_000;
268
/// Uniform parameters uploaded to the decode shader (bind group entry 4).
///
/// `#[repr(C)]` + `Pod` allow a byte-for-byte upload via
/// `safe_bytes_of_slice`; field order presumably mirrors the WGSL-side
/// uniform struct — confirm against `format.wgsl()` before reordering.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
pub struct Params {
    // Number of input bytes (validated to fit u32 before dispatch).
    input_len: u32,
    // Minimum run length for the active format/rules (`format.min_run(rules)`).
    min_run: u32,
    // Cap on regions the shader may record (input_len clamped to MAX_DECODE_REGIONS).
    max_regions: u32,
    // Output capacity in words; set equal to `input_len` at dispatch.
    output_size: u32,
}
278
/// One decoded-region record as written by the shader into the regions
/// buffer (bind group entry 1); 16 bytes, matching `REGION_META_SIZE`.
///
/// All fields are treated as untrusted GPU output and bounds-checked during
/// readback in `readback_regions`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
pub struct RegionMeta {
    // Byte offset of the region within the input.
    src_offset: u32,
    // Byte length of the region within the input; 0 marks an unused slot.
    src_len: u32,
    // Word offset of the decoded bytes within the output buffer.
    dst_offset: u32,
    // Number of decoded output words; 0 marks an unused slot.
    dst_len: u32,
}
288
289pub fn validate_gpu_sizes(device: &wgpu::Device, input_len: u32, input_size: usize) -> Result<()> {
291 let limits = device.limits();
292 let gpu_limit = u64::from(limits.max_storage_buffer_binding_size).min(limits.max_buffer_size);
293 let regions_bytes = u64::from(input_len.min(MAX_DECODE_REGIONS)) * REGION_META_SIZE;
294 let output_bytes = u64::from(input_len) * 4;
295 if regions_bytes > gpu_limit || output_bytes > gpu_limit {
296 return Err(Error::Gpu {
297 message: format!(
298 "input size {input_size} exceeds GPU buffer limit ({regions_bytes} byte regions buffer, {output_bytes} byte output buffer, {gpu_limit} limit). Fix: split the input or run on an adapter with larger storage buffers."
299 ),
300 });
301 }
302 Ok(())
303}
304
305fn align_storage_bytes(len: usize) -> Result<u64> {
306 let aligned = len.max(1).next_multiple_of(4);
307 u64::try_from(aligned).map_err(|source| Error::Decode {
308 message: format!(
309 "decode input buffer size {aligned} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
310 ),
311 })
312}
313
/// Zero-fill the gap between `written` bytes of real data and the `total`
/// allocated buffer size (a no-op when `written >= total`).
///
/// NOTE(review): WebGPU's `writeBuffer` requires both the byte offset and the
/// data size to be multiples of 4 (COPY_BUFFER_ALIGNMENT). When `written` is
/// not 4-aligned this issues a 1-3 byte write at an unaligned offset, which
/// presumably fails wgpu validation — confirm callers only reach this with a
/// 4-aligned `written`, or pad the upload on the CPU instead.
///
/// NOTE(review): the 4-byte scratch array means `total - written` must be at
/// most 4 or the slice below panics; holds today because `total` comes from
/// `align_storage_bytes` (next multiple of 4 of `written`) — confirm for any
/// new caller.
fn write_zero_padding(
    queue: &wgpu::Queue,
    buffer: &wgpu::Buffer,
    written: usize,
    total: u64,
) -> Result<()> {
    let written_u64 = u64::try_from(written).map_err(|source| Error::Decode {
        message: format!(
            "decode written byte count {written} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
        ),
    })?;
    // Already full (or over-full): nothing to pad.
    if written_u64 >= total {
        return Ok(());
    }
    let padding_len = usize::try_from(total - written_u64).map_err(|source| Error::Decode {
        message: format!(
            "decode zero padding length cannot fit usize: {source}. Fix: split the decode input before GPU dispatch."
        ),
    })?;
    let padding = [0u8; 4];
    queue.write_buffer(buffer, written_u64, &padding[..padding_len]);
    Ok(())
}
337
/// Block until `submission` completes, then map `buffer` for reading and copy
/// its contents out as a `Vec<u32>`.
///
/// Protocol (order matters): `map_async` queues the map request; the blocking
/// `device.poll(Maintain::wait_for(..))` drives the device until the
/// submission — and therefore the map callback — is processed; the mpsc
/// channel carries the map result back to this thread; the mapped range is
/// copied out and dropped before `unmap`.
///
/// The buffer must have `MAP_READ` usage, and its size must be a multiple of
/// 4 bytes for the u32 cast to succeed.
///
/// # Errors
/// `Error::Gpu` if the channel closes or mapping fails (device loss, invalid
/// usage), `Error::Decode` if the mapped bytes cannot be cast to `u32`s.
pub fn map_readback_u32(
    device: &wgpu::Device,
    buffer: &wgpu::Buffer,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<u32>> {
    let slice = buffer.slice(..);
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        // The receiver only disappears if this function already returned; log
        // instead of panicking inside the wgpu callback.
        if let Err(send_err) = sender.send(result) {
            tracing::warn!(
                ?send_err,
                "decode readback receiver dropped before map_async result delivery"
            );
        }
    });
    // Both variants mean the wait finished; matched exhaustively so a new
    // `MaintainResult` variant forces a compile error here.
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .map_err(|source| Error::Gpu {
            message: format!("readback channel closed unexpectedly: {source}. Fix: keep the decode readback receiver alive until map_async completes."),
        })?
        .map_err(|error| Error::Gpu {
            message: format!("map_async failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
        })?;
    let mapped = slice.get_mapped_range();
    let out = safe_cast_slice::<u32>(&mapped)
        .map_err(|error| Error::Decode {
            message: format!(
                "safe cast failed in map_readback_u32: {error}. Fix: ensure the readback buffer is aligned and sized correctly."
            ),
        })?
        .to_vec();
    // The mapped view must be dropped before unmap or wgpu panics.
    drop(mapped);
    buffer.unmap();
    Ok(out)
}
377
378pub fn map_readback_region_meta(
380 device: &wgpu::Device,
381 buffer: &wgpu::Buffer,
382 submission: wgpu::SubmissionIndex,
383) -> Result<Vec<RegionMeta>> {
384 let words = map_readback_u32(device, buffer, submission)?;
385 let region_words = usize::try_from(MAX_DECODE_REGIONS)
386 .map_err(|source| Error::Decode {
387 message: format!(
388 "MAX_DECODE_REGIONS cannot fit usize: {source}. Fix: run on a supported target."
389 ),
390 })?
391 .checked_mul(4)
392 .ok_or_else(|| Error::Decode {
393 message: "decode region word bound overflow. Fix: lower MAX_DECODE_REGIONS."
394 .to_string(),
395 })?;
396 if words.len() > region_words {
397 return Err(Error::Decode {
398 message: format!(
399 "decode region metadata contains {} u32 words, exceeding {region_words}. Fix: split the input before GPU decode dispatch.",
400 words.len()
401 ),
402 });
403 }
404 let mut regions = Vec::with_capacity(words.len() / 4);
405 for chunk in words.chunks_exact(4) {
406 regions.push(RegionMeta {
407 src_offset: chunk[0],
408 src_len: chunk[1],
409 dst_offset: chunk[2],
410 dst_len: chunk[3],
411 });
412 }
413 Ok(regions)
414}
415
416pub fn zeroed_storage(device: &wgpu::Device, label: &str, bytes: usize) -> Result<PooledBuffer> {
418 BufferPool::global().acquire(
419 device,
420 label,
421 u64::try_from(bytes).map_err(|source| Error::Decode {
422 message: format!(
423 "decode zeroed storage size {bytes} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
424 ),
425 })?,
426 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
427 )
428}
429
430pub fn readback_buffer(
432 device: &wgpu::Device,
433 encoder: &mut wgpu::CommandEncoder,
434 source: &wgpu::Buffer,
435 size: u64,
436) -> Result<PooledBuffer> {
437 let readback = BufferPool::global().acquire(
438 device,
439 "vyre decode readback",
440 size,
441 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
442 )?;
443 encoder.copy_buffer_to_buffer(source, 0, &readback, 0, size);
444 Ok(readback)
445}