Skip to main content

vyre_reference/
oob.rs

1//! Out-of-bounds rules enforced by the parity engine.
2//!
3//! GPU drivers differ on what happens when a shader indexes past the end of a
4//! buffer: some clamp, some return zero, some crash. The reference interpreter
5//! eliminates that ambiguity by defining one deterministic behavior — defined-type
6//! zero-fill for scalar loads, empty slice for `Bytes`, and silent no-op for stores.
7//! Any backend that diverges from these rules fails the conform gate.
8
9use vyre::ir::DataType as IrDataType;
10
11use crate::value::Value;
12use vyre::ir::DataType;
13
14/// Typed bytes backing one declared IR buffer.
15///
16/// This struct exists to give the reference interpreter a single place to enforce
17/// stride-correct indexing and OOB semantics, independent of how any GPU driver
18/// handles buffer bounds.
19#[derive(Debug, Clone)]
20pub struct Buffer {
21    pub(crate) bytes: Vec<u8>,
22    pub(crate) element: IrDataType,
23}
24
25impl Buffer {
26    pub(crate) fn len(&self) -> u32 {
27        // The engine indexes buffers with u32 (WGSL-native). Lengths beyond
28        // u32::MAX are unreachable through the IR's index space, so saturate
29        // rather than truncating — anything past u32::MAX maps to the OOB
30        // tail where loads zero-fill and stores no-op.
31        let stride = self.element.min_bytes();
32        let count = if stride == 0 {
33            self.bytes.len()
34        } else {
35            self.bytes.len() / stride
36        };
37        u32::try_from(count).unwrap_or(u32::MAX)
38    }
39}
40
41pub(crate) fn load(buffer: &Buffer, index: u32) -> Value {
42    let stride = buffer.element.min_bytes();
43    let ty = ir_to_conform_type(buffer.element.clone());
44    if matches!(buffer.element, IrDataType::Bytes) {
45        let offset = index as usize;
46        if offset > buffer.bytes.len() {
47            return Value::Bytes(Vec::new());
48        }
49        return Value::Bytes(buffer.bytes[offset..].to_vec());
50    }
51    let Some(offset) = byte_offset(index, stride) else {
52        return Value::try_zero_for(ty).unwrap_or(Value::Bytes(Vec::new()));
53    };
54    if stride == 0 || offset + stride > buffer.bytes.len() {
55        return Value::try_zero_for(ty).unwrap_or(Value::Bytes(Vec::new()));
56    }
57    Value::from_element_bytes(ty.clone(), &buffer.bytes[offset..offset + stride])
58        .unwrap_or_else(|_| Value::try_zero_for(ty).unwrap_or(Value::Bytes(Vec::new())))
59}
60
61pub(crate) fn store(buffer: &mut Buffer, index: u32, value: &Value) {
62    let stride = buffer.element.min_bytes();
63    if matches!(buffer.element, IrDataType::Bytes) {
64        let offset = index as usize;
65        if offset >= buffer.bytes.len() {
66            return; // OOB: silent no-op, matching GPU buffer semantics.
67        }
68        let bytes = value.to_bytes();
69        // Clamp to the remaining buffer capacity — GPU buffers cannot grow.
70        let available = buffer.bytes.len() - offset;
71        let write_len = bytes.len().min(available);
72        buffer.bytes[offset..offset + write_len].copy_from_slice(&bytes[..write_len]);
73        return;
74    }
75    let Some(offset) = byte_offset(index, stride) else {
76        return;
77    };
78    if stride == 0 || offset + stride > buffer.bytes.len() {
79        return;
80    }
81    write_element(
82        buffer.element.clone(),
83        &mut buffer.bytes[offset..offset + stride],
84        value,
85    );
86}
87
88pub(crate) fn atomic_load(buffer: &Buffer, index: u32) -> Option<u32> {
89    // Kimi audit finding #7: index must be scaled by the buffer's
90    // declared element stride, not hardcoded to 4. A U64 buffer at
91    // index 1 sits at byte offset 8, not 4. The previous hardcode
92    // caused atomic ops on wider-than-u32 elements to overlap every
93    // pair of elements and corrupt the reference semantics.
94    let stride = buffer.element.min_bytes().max(4);
95    let offset = byte_offset(index, stride)?;
96    if offset + 4 > buffer.bytes.len() {
97        None
98    } else {
99        Some(read_u32(&buffer.bytes[offset..offset + 4]))
100    }
101}
102
103pub(crate) fn atomic_store(buffer: &mut Buffer, index: u32, value: u32) {
104    // See atomic_load — stride must come from the element type, not
105    // be hardcoded.
106    let stride = buffer.element.min_bytes().max(4);
107    let Some(offset) = byte_offset(index, stride) else {
108        return;
109    };
110    if offset + 4 <= buffer.bytes.len() {
111        write_u32(&mut buffer.bytes[offset..offset + 4], value);
112    }
113}
114
/// Converts an element `index` into a byte offset for elements `stride` bytes
/// wide.
///
/// Returns `None` when `index * stride` does not fit in `usize`.
fn byte_offset(index: u32, stride: usize) -> Option<usize> {
    let element = index as usize;
    stride.checked_mul(element)
}
118
119fn write_element(element: IrDataType, target: &mut [u8], value: &Value) {
120    match element {
121        IrDataType::U32 => {
122            target.copy_from_slice(&value.to_bytes_width(4)[..4]);
123        }
124        IrDataType::I32 => {
125            target.copy_from_slice(&value.to_bytes_width(4)[..4]);
126        }
127        IrDataType::Bool => {
128            target.copy_from_slice(&value.to_bytes_width(4)[..4]);
129        }
130        IrDataType::U64 => {
131            let bytes = value.to_bytes_width(8);
132            target.copy_from_slice(&bytes[..8]);
133        }
134        IrDataType::F32 => {
135            // Value::Float carries an f64; the GPU buffer is four bytes
136            // of f32, so narrow via `as f32` before writing. Dropping the
137            // upper four bytes of `v.to_le_bytes()` (what the default
138            // to_bytes_width path does) would mangle the f32 bit pattern.
139            let v = match value {
140                Value::Float(v) => *v as f32,
141                Value::U32(v) => f32::from_bits(*v),
142                _ => 0.0,
143            };
144            target.copy_from_slice(&v.to_le_bytes());
145        }
146        IrDataType::Bytes | IrDataType::Vec2U32 | IrDataType::Vec4U32 => {
147            let bytes = value.to_bytes_width(target.len());
148            let len = target.len().min(bytes.len());
149            target[..len].copy_from_slice(&bytes[..len]);
150            target[len..].fill(0);
151        }
152        _ => {
153            let bytes = value.to_bytes_width(target.len());
154            let len = target.len().min(bytes.len());
155            target[..len].copy_from_slice(&bytes[..len]);
156            target[len..].fill(0);
157        }
158    }
159}
160
/// Decodes the first four bytes of `bytes` as a little-endian `u32`.
///
/// # Panics
///
/// Panics if `bytes` holds fewer than four bytes.
fn read_u32(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
164
/// Encodes `value` as little-endian into `bytes`.
///
/// # Panics
///
/// Panics unless `bytes` is exactly four bytes long.
fn write_u32(bytes: &mut [u8], value: u32) {
    let encoded: [u8; 4] = value.to_le_bytes();
    bytes.copy_from_slice(encoded.as_slice());
}
168
169fn ir_to_conform_type(ty: IrDataType) -> DataType {
170    match ty {
171        IrDataType::U32 => DataType::U32,
172        IrDataType::I32 => DataType::I32,
173        IrDataType::U64 => DataType::U64,
174        IrDataType::F32 => DataType::F32,
175        IrDataType::Vec2U32 => DataType::Vec2U32,
176        IrDataType::Vec4U32 => DataType::Vec4U32,
177        IrDataType::Bool => DataType::U32,
178        IrDataType::Bytes => DataType::Bytes,
179        _ => DataType::Bytes,
180    }
181}