Skip to main content

vyre_wgpu/engine/decode/dispatch/
gpu.rs

1//! GPU dispatch for the decode engine.
2
3use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
4use crate::engine::decode::{DecodeFormat, DecodeRules, DecodedRegion};
5use crate::runtime::cache::{BufferPool, PooledBuffer};
6use crate::runtime::{bg_entry, compile_compute_pipeline};
7use vyre::{Error, Result};
8use bytemuck::{Pod, Zeroable};
9use std::sync::mpsc;
10
11/// Fixes architecture_deep_audit.md#10/#13: crate-private visibility avoids
12/// restricted visibility audit blind spots.
13pub(crate) fn dispatch_decode(
14    format: DecodeFormat,
15    input: &[u8],
16    rules: &DecodeRules,
17    command_encoder: Option<&mut wgpu::CommandEncoder>,
18) -> Result<Vec<DecodedRegion>> {
19    if input.is_empty() {
20        return Ok(Vec::new());
21    }
22    if input.len() > MAX_DECODE_INPUT_BYTES {
23        return Err(Error::Decode {
24            message: format!(
25                "decode input is {} bytes, exceeding {MAX_DECODE_INPUT_BYTES}. Fix: split the input before GPU decode dispatch.",
26                input.len()
27            ),
28        });
29    }
30    let (device, queue) = crate::runtime::cached_device()?;
31    let input_len = u32::try_from(input.len()).map_err(|source| Error::Decode {
32        message: format!("input size {} exceeds u32::MAX: {source}. Fix: split the decode input before GPU dispatch.", input.len()),
33    })?;
34    validate_gpu_sizes(device, input_len, input.len())?;
35    let max_regions = input_len.min(MAX_DECODE_REGIONS);
36    let region_bytes = usize::try_from(max_regions)
37        .map_err(|source| Error::Decode {
38            message: format!("decode max_regions {max_regions} cannot fit usize: {source}. Fix: run on a supported target."),
39        })?
40        .checked_mul(size_of::<RegionMeta>())
41        .ok_or_else(|| Error::Decode {
42            message: "decode regions buffer size overflow. Fix: split the decode input before GPU dispatch.".to_string(),
43        })?;
44
45    let params = Params {
46        input_len,
47        min_run: format.min_run(rules),
48        max_regions,
49        output_size: input_len,
50    };
51    let pool = BufferPool::global();
52    let input_bytes = align_storage_bytes(input.len())?;
53    let input_buf = pool.acquire(
54        device,
55        "vyre decode input",
56        input_bytes,
57        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
58    )?;
59    queue.write_buffer(&input_buf, 0, input);
60    write_zero_padding(queue, &input_buf, input.len(), input_bytes)?;
61    let regions_buf = zeroed_storage(device, "vyre decode regions", region_bytes.max(16))?;
62    let output_bytes = input.len().checked_mul(4).ok_or_else(|| Error::Decode {
63        message:
64            "decode output buffer size overflow. Fix: split the decode input before GPU dispatch."
65                .to_string(),
66    })?;
67    let output_buf = zeroed_storage(device, "vyre decode output", output_bytes.max(16))?;
68    let counters_buf = zeroed_storage(device, "vyre decode counters", 16)?;
69    let params_array = [params];
70    let params_bytes = safe_bytes_of_slice(&params_array);
71    let params_buf = pool.acquire(
72        device,
73        "vyre decode params",
74        u64::try_from(params_bytes.len()).map_err(|source| Error::Decode {
75            message: format!("decode params buffer size cannot fit u64: {source}. Fix: run on a supported target."),
76        })?,
77        wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
78    )?;
79    queue.write_buffer(&params_buf, 0, params_bytes);
80
81    vyre::ops::registry::gate::verify_certificate(format.op_id()).map_err(|source| Error::Decode {
82        message: source.to_string(),
83    })?;
84
85    let wgsl = format.wgsl();
86    let pipeline = compile_compute_pipeline(device, format.label(), &wgsl, "main")?;
87    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
88        label: Some("vyre decode bind group"),
89        layout: &pipeline.get_bind_group_layout(0),
90        entries: &[
91            bg_entry(0, &input_buf),
92            bg_entry(1, &regions_buf),
93            bg_entry(2, &output_buf),
94            bg_entry(3, &counters_buf),
95            bg_entry(4, &params_buf),
96        ],
97    });
98
99    let mut owned_encoder = command_encoder.is_none().then(|| {
100        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
101            label: Some("vyre decode encoder"),
102        })
103    });
104    let encoder = if let Some(encoder) = command_encoder {
105        encoder
106    } else {
107        owned_encoder
108            .as_mut()
109            .expect("owned encoder must be present when command_encoder is omitted")
110    };
111    encoder.clear_buffer(&regions_buf, 0, Some(regions_buf.size()));
112    encoder.clear_buffer(&output_buf, 0, Some(output_buf.size()));
113    encoder.clear_buffer(&counters_buf, 0, Some(counters_buf.size()));
114    {
115        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
116            label: Some("vyre decode pass"),
117            timestamp_writes: None,
118        });
119        pass.set_pipeline(&pipeline);
120        pass.set_bind_group(0, &bind_group, &[]);
121        pass.dispatch_workgroups(input_len.div_ceil(WORKGROUP_SIZE), 1, 1);
122    }
123    let counters_readback = readback_buffer(device, encoder, &counters_buf, 16)?;
124    let regions_readback = readback_buffer(
125        device,
126        encoder,
127        &regions_buf,
128        u64::try_from(region_bytes.max(16)).map_err(|source| Error::Decode {
129            message: format!("decode region readback size cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."),
130        })?,
131    )?;
132    let output_readback = readback_buffer(
133        device,
134        encoder,
135        &output_buf,
136        u64::try_from(output_bytes.max(16)).map_err(|source| Error::Decode {
137            message: format!("decode output readback size cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."),
138        })?,
139    )?;
140    let Some(owned_encoder) = owned_encoder else {
141        return Err(Error::Decode {
142            message: "dispatch_decode was called with an external command encoder, but this API returns decoded readback data that is unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred decode API that returns readback buffers.".to_string(),
143        });
144    };
145    let submission = queue.submit(Some(owned_encoder.finish()));
146    let regions = readback_regions(
147        device,
148        &counters_readback,
149        &regions_readback,
150        &output_readback,
151        input.len(),
152        submission,
153    )?;
154    pool.release(input_buf);
155    pool.release(regions_buf);
156    pool.release(output_buf);
157    pool.release(counters_buf);
158    pool.release(params_buf);
159    pool.release(counters_readback);
160    pool.release(regions_readback);
161    pool.release(output_readback);
162    Ok(regions)
163}
164
165/// `readback_regions` function.
166pub fn readback_regions(
167    device: &wgpu::Device,
168    counters_readback: &wgpu::Buffer,
169    regions_readback: &wgpu::Buffer,
170    output_readback: &wgpu::Buffer,
171    input_len: usize,
172    submission: wgpu::SubmissionIndex,
173) -> Result<Vec<DecodedRegion>> {
174    let counters = map_readback_u32(device, counters_readback, submission.clone())?;
175    let region_meta = map_readback_region_meta(device, regions_readback, submission.clone())?;
176    let output_words = map_readback_u32(device, output_readback, submission)?;
177    let max_regions = u32::try_from(input_len).map_err(|source| Error::Decode {
178        message: format!("input length {input_len} cannot fit u32 while bounding readback: {source}. Fix: split the decode input before GPU dispatch."),
179    })?.min(MAX_DECODE_REGIONS);
180    let region_count = usize::try_from(counters.first().copied().unwrap_or(0).min(max_regions))
181        .map_err(|source| Error::Decode {
182            message: format!("region count cannot fit usize: {source}. Fix: reject this GPU readback on this platform."),
183        })?;
184    if region_count
185        > usize::try_from(MAX_DECODE_REGIONS).map_err(|source| Error::Decode {
186            message: format!(
187                "MAX_DECODE_REGIONS cannot fit usize: {source}. Fix: run on a supported target."
188            ),
189        })?
190    {
191        return Err(Error::Decode {
192            message: format!(
193                "GPU region count {region_count} exceeds {MAX_DECODE_REGIONS}. Fix: reject this malformed decoder output."
194            ),
195        });
196    }
197    let mut decoded = Vec::with_capacity(region_count);
198    for meta in region_meta.into_iter().take(region_count) {
199        if meta.src_len == 0 || meta.dst_len == 0 {
200            continue;
201        }
202        let src_offset = usize::try_from(meta.src_offset).map_err(|source| Error::Decode {
203            message: format!("source offset {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.src_offset),
204        })?;
205        let src_len = usize::try_from(meta.src_len).map_err(|source| Error::Decode {
206            message: format!("source length {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.src_len),
207        })?;
208        let src_end = src_offset.checked_add(src_len).ok_or_else(|| Error::Decode {
209            message: "GPU readback src_offset+src_len overflow. Fix: reject this malformed decoder output.".to_string(),
210        })?;
211        if src_end > input_len {
212            return Err(Error::Decode {
213                message: format!(
214                    "GPU source region [{src_offset}, {src_end}) exceeds input length {input_len}. Fix: reject this malformed decoder output and inspect the decode shader."
215                ),
216            });
217        }
218        let dst_start = usize::try_from(meta.dst_offset).map_err(|source| Error::Decode {
219            message: format!("output offset {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.dst_offset),
220        })?;
221        let dst_end = dst_start
222            .checked_add(usize::try_from(meta.dst_len).map_err(|source| Error::Decode {
223                message: format!("output length {} cannot fit usize: {source}. Fix: reject this GPU readback on this platform.", meta.dst_len),
224            })?)
225            .ok_or_else(|| Error::Decode {
226                message: "output region overflow during readback. Fix: reject this malformed decoder output.".to_string(),
227            })?;
228        if dst_end > output_words.len() {
229            return Err(Error::Decode {
230                message: "shader emitted output beyond allocated readback storage. Fix: reject this malformed decoder output and inspect the decode shader.".to_string(),
231            });
232        }
233        let decoded_bytes = output_words[dst_start..dst_end]
234            .iter()
235            .map(|word| {
236                u8::try_from(*word & 0xff).map_err(|source| Error::Decode {
237                    message: format!("masked output byte could not fit u8: {source}. Fix: report this impossible conversion failure."),
238                })
239            })
240            .collect::<Result<Vec<_>>>()?;
241        decoded.push(DecodedRegion {
242            offset: src_offset,
243            length: src_len,
244            decoded_bytes,
245        });
246    }
247    Ok(decoded)
248}
249
/// Threads per workgroup used to size the one-dimensional decode dispatch
/// (`ceil(input_len / WORKGROUP_SIZE)` workgroups). This presumably mirrors the
/// `@workgroup_size` declared by the decode shaders; confirm against the WGSL source.
pub const WORKGROUP_SIZE: u32 = 64;
252
/// Size in bytes of one region-metadata record in the GPU region buffer: four `u32` words
/// (source offset, source length, output offset, output length).
pub const REGION_META_SIZE: u64 = 4 * 4;
255
/// Largest input, in bytes, that a single GPU decode dispatch accepts (64 MiB).
///
/// I10: the input, region-metadata and output buffers are all sized from the input length.
/// This cap rejects oversized requests before any `Vec::with_capacity` or GPU staging
/// allocation reserves memory. Device limits may still reject smaller inputs.
pub const MAX_DECODE_INPUT_BYTES: usize = 64 << 20;
262
/// Upper bound on the number of decoded regions read back from one dispatch.
///
/// I10: this caps both the region-metadata buffer and the region vector allocation. A
/// malformed shader counter (one region per byte, or a corrupted value) therefore cannot
/// force an unbounded allocation.
pub const MAX_DECODE_REGIONS: u32 = 1_000_000;
268
269#[repr(C)]
270#[derive(Clone, Copy, Pod, Zeroable)]
271/// `Params` struct.
272pub struct Params {
273    input_len: u32,
274    min_run: u32,
275    max_regions: u32,
276    output_size: u32,
277}
278
279#[repr(C)]
280#[derive(Clone, Copy, Pod, Zeroable)]
281/// `RegionMeta` struct.
282pub struct RegionMeta {
283    src_offset: u32,
284    src_len: u32,
285    dst_offset: u32,
286    dst_len: u32,
287}
288
289/// `validate_gpu_sizes` function.
290pub fn validate_gpu_sizes(device: &wgpu::Device, input_len: u32, input_size: usize) -> Result<()> {
291    let limits = device.limits();
292    let gpu_limit = u64::from(limits.max_storage_buffer_binding_size).min(limits.max_buffer_size);
293    let regions_bytes = u64::from(input_len.min(MAX_DECODE_REGIONS)) * REGION_META_SIZE;
294    let output_bytes = u64::from(input_len) * 4;
295    if regions_bytes > gpu_limit || output_bytes > gpu_limit {
296        return Err(Error::Gpu {
297            message: format!(
298                "input size {input_size} exceeds GPU buffer limit ({regions_bytes} byte regions buffer, {output_bytes} byte output buffer, {gpu_limit} limit). Fix: split the input or run on an adapter with larger storage buffers."
299            ),
300        });
301    }
302    Ok(())
303}
304
305fn align_storage_bytes(len: usize) -> Result<u64> {
306    let aligned = len.max(1).next_multiple_of(4);
307    u64::try_from(aligned).map_err(|source| Error::Decode {
308        message: format!(
309            "decode input buffer size {aligned} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
310        ),
311    })
312}
313
314fn write_zero_padding(
315    queue: &wgpu::Queue,
316    buffer: &wgpu::Buffer,
317    written: usize,
318    total: u64,
319) -> Result<()> {
320    let written_u64 = u64::try_from(written).map_err(|source| Error::Decode {
321        message: format!(
322            "decode written byte count {written} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
323        ),
324    })?;
325    if written_u64 >= total {
326        return Ok(());
327    }
328    let padding_len = usize::try_from(total - written_u64).map_err(|source| Error::Decode {
329        message: format!(
330            "decode zero padding length cannot fit usize: {source}. Fix: split the decode input before GPU dispatch."
331        ),
332    })?;
333    let padding = [0u8; 4];
334    queue.write_buffer(buffer, written_u64, &padding[..padding_len]);
335    Ok(())
336}
337
338/// `map_readback_u32` function.
339pub fn map_readback_u32(
340    device: &wgpu::Device,
341    buffer: &wgpu::Buffer,
342    submission: wgpu::SubmissionIndex,
343) -> Result<Vec<u32>> {
344    let slice = buffer.slice(..);
345    let (sender, receiver) = mpsc::channel();
346    slice.map_async(wgpu::MapMode::Read, move |result| {
347        if let Err(send_err) = sender.send(result) {
348            tracing::warn!(
349                ?send_err,
350                "decode readback receiver dropped before map_async result delivery"
351            );
352        }
353    });
354    match device.poll(wgpu::Maintain::wait_for(submission)) {
355        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
356    }
357    receiver
358        .recv()
359        .map_err(|source| Error::Gpu {
360            message: format!("readback channel closed unexpectedly: {source}. Fix: keep the decode readback receiver alive until map_async completes."),
361        })?
362        .map_err(|error| Error::Gpu {
363            message: format!("map_async failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
364        })?;
365    let mapped = slice.get_mapped_range();
366    let out = safe_cast_slice::<u32>(&mapped)
367        .map_err(|error| Error::Decode {
368            message: format!(
369                "safe cast failed in map_readback_u32: {error}. Fix: ensure the readback buffer is aligned and sized correctly."
370            ),
371        })?
372        .to_vec();
373    drop(mapped);
374    buffer.unmap();
375    Ok(out)
376}
377
378/// `map_readback_region_meta` function.
379pub fn map_readback_region_meta(
380    device: &wgpu::Device,
381    buffer: &wgpu::Buffer,
382    submission: wgpu::SubmissionIndex,
383) -> Result<Vec<RegionMeta>> {
384    let words = map_readback_u32(device, buffer, submission)?;
385    let region_words = usize::try_from(MAX_DECODE_REGIONS)
386        .map_err(|source| Error::Decode {
387            message: format!(
388                "MAX_DECODE_REGIONS cannot fit usize: {source}. Fix: run on a supported target."
389            ),
390        })?
391        .checked_mul(4)
392        .ok_or_else(|| Error::Decode {
393            message: "decode region word bound overflow. Fix: lower MAX_DECODE_REGIONS."
394                .to_string(),
395        })?;
396    if words.len() > region_words {
397        return Err(Error::Decode {
398            message: format!(
399                "decode region metadata contains {} u32 words, exceeding {region_words}. Fix: split the input before GPU decode dispatch.",
400                words.len()
401            ),
402        });
403    }
404    let mut regions = Vec::with_capacity(words.len() / 4);
405    for chunk in words.chunks_exact(4) {
406        regions.push(RegionMeta {
407            src_offset: chunk[0],
408            src_len: chunk[1],
409            dst_offset: chunk[2],
410            dst_len: chunk[3],
411        });
412    }
413    Ok(regions)
414}
415
416/// `zeroed_storage` function.
417pub fn zeroed_storage(device: &wgpu::Device, label: &str, bytes: usize) -> Result<PooledBuffer> {
418    BufferPool::global().acquire(
419        device,
420        label,
421        u64::try_from(bytes).map_err(|source| Error::Decode {
422            message: format!(
423                "decode zeroed storage size {bytes} cannot fit u64: {source}. Fix: split the decode input before GPU dispatch."
424            ),
425        })?,
426        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
427    )
428}
429
430/// `readback_buffer` function.
431pub fn readback_buffer(
432    device: &wgpu::Device,
433    encoder: &mut wgpu::CommandEncoder,
434    source: &wgpu::Buffer,
435    size: u64,
436) -> Result<PooledBuffer> {
437    let readback = BufferPool::global().acquire(
438        device,
439        "vyre decode readback",
440        size,
441        wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
442    )?;
443    encoder.copy_buffer_to_buffer(source, 0, &readback, 0, size);
444    Ok(readback)
445}