Skip to main content

vyre_reference/
oob.rs

1//! Out-of-bounds rules enforced by the parity engine.
2//!
3//! GPU drivers differ on what happens when a shader indexes past the end of a
4//! buffer: some clamp, some return zero, some crash. The reference interpreter
5//! eliminates that ambiguity by defining one deterministic behavior — defined-type
6//! zero-fill for scalar loads, empty slice for `Bytes`, and silent no-op for stores.
7//! Any backend that diverges from these rules fails the conform gate.
8
9use vyre::ir::DataType as IrDataType;
10
11use crate::value::Value;
12use vyre::ir::DataType;
13
14use std::sync::{Arc, RwLock};
15
/// Typed bytes backing one declared IR buffer.
///
/// This struct exists to give the reference interpreter a single place to enforce
/// stride-correct indexing and OOB semantics, independent of how any GPU driver
/// handles buffer bounds.
#[derive(Debug, Clone)]
pub struct Buffer {
    // Raw backing storage. Wrapped in `Arc<RwLock<..>>` so `Clone`d handles to
    // the same declared buffer share — and may concurrently mutate — one byte vec.
    pub(crate) bytes: Arc<RwLock<Vec<u8>>>,
    // Declared element type; its `min_bytes()` is the stride used for all
    // indexing in this module.
    pub(crate) element: IrDataType,
}
26
27impl Buffer {
28    /// Create a buffer from typed bytes.
29    #[must_use]
30    pub fn new(bytes: Vec<u8>, element: DataType) -> Self {
31        Self {
32            bytes: Arc::new(RwLock::new(bytes)),
33            element,
34        }
35    }
36
37    pub(crate) fn len(&self) -> u32 {
38        let bytes_guard = self.bytes.read().unwrap_or_else(|error| error.into_inner());
39        let stride = self.element.min_bytes();
40        let count = if stride == 0 {
41            bytes_guard.len()
42        } else {
43            bytes_guard.len() / stride
44        };
45        u32::try_from(count).unwrap_or(u32::MAX)
46    }
47
48    pub(crate) fn byte_len(&self) -> usize {
49        self.bytes
50            .read()
51            .unwrap_or_else(|error| error.into_inner())
52            .len()
53    }
54
55    pub(crate) fn element(&self) -> &IrDataType {
56        &self.element
57    }
58
59    pub(crate) fn zero_fill(&self) {
60        self.bytes
61            .write()
62            .unwrap_or_else(|error| error.into_inner())
63            .fill(0);
64    }
65
66    /// Consume this buffer and return its contents as a Value.
67    #[must_use]
68    pub fn to_value(self) -> crate::value::Value {
69        let vec = std::sync::Arc::try_unwrap(self.bytes)
70            .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
71            .unwrap_or_else(|a| a.read().unwrap_or_else(|error| error.into_inner()).clone());
72        crate::value::Value::from(vec)
73    }
74}
75
76pub(crate) fn load(buffer: &Buffer, index: u32) -> Value {
77    let bytes_guard = buffer
78        .bytes
79        .read()
80        .unwrap_or_else(|error| error.into_inner());
81    let stride = buffer.element.min_bytes();
82    let ty = ir_to_conform_type(buffer.element.clone());
83    if matches!(buffer.element, IrDataType::Bytes) {
84        let offset = index as usize;
85        if offset > bytes_guard.len() {
86            return Value::from(Vec::new());
87        }
88        return Value::from(&bytes_guard[offset..]);
89    }
90    let Some(offset) = byte_offset(index, stride) else {
91        return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
92    };
93    if stride == 0 || offset + stride > bytes_guard.len() {
94        return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
95    }
96    Value::from_element_bytes(ty.clone(), &bytes_guard[offset..offset + stride])
97        .unwrap_or_else(|_| Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new())))
98}
99
100pub(crate) fn store(buffer: &mut Buffer, index: u32, value: &Value) {
101    let mut bytes_guard = buffer
102        .bytes
103        .write()
104        .unwrap_or_else(|error| error.into_inner());
105    let stride = buffer.element.min_bytes();
106    if matches!(buffer.element, IrDataType::Bytes) {
107        let offset = index as usize;
108        if offset >= bytes_guard.len() {
109            return;
110        }
111        let bytes = value.to_bytes();
112        let available = bytes_guard.len() - offset;
113        let write_len = bytes.len().min(available);
114        bytes_guard[offset..offset + write_len].copy_from_slice(&bytes[..write_len]);
115        return;
116    }
117    let Some(offset) = byte_offset(index, stride) else {
118        return;
119    };
120    if stride == 0 || offset + stride > bytes_guard.len() {
121        return;
122    }
123    write_element(
124        buffer.element.clone(),
125        &mut bytes_guard[offset..offset + stride],
126        value,
127    );
128}
129
130pub(crate) fn atomic_load(buffer: &Buffer, index: u32) -> Option<u32> {
131    let bytes_guard = buffer
132        .bytes
133        .read()
134        .unwrap_or_else(|error| error.into_inner());
135    let stride = buffer.element.min_bytes().max(4);
136    let offset = byte_offset(index, stride)?;
137    if offset + 4 > bytes_guard.len() {
138        None
139    } else {
140        Some(read_u32(&bytes_guard[offset..offset + 4]))
141    }
142}
143
144pub(crate) fn atomic_store(buffer: &mut Buffer, index: u32, value: u32) {
145    let mut bytes_guard = buffer
146        .bytes
147        .write()
148        .unwrap_or_else(|error| error.into_inner());
149    let stride = buffer.element.min_bytes().max(4);
150    let Some(offset) = byte_offset(index, stride) else {
151        return;
152    };
153    if offset + 4 <= bytes_guard.len() {
154        write_u32(&mut bytes_guard[offset..offset + 4], value);
155    }
156}
157
/// Byte offset of element `index` at the given stride, or `None` when the
/// multiplication would overflow `usize`.
fn byte_offset(index: u32, stride: usize) -> Option<usize> {
    stride.checked_mul(index as usize)
}
161
162fn write_element(element: IrDataType, target: &mut [u8], value: &Value) {
163    match element {
164        IrDataType::U32 => {
165            target.copy_from_slice(&value.to_bytes_width(4)[..4]);
166        }
167        IrDataType::I32 => {
168            target.copy_from_slice(&value.to_bytes_width(4)[..4]);
169        }
170        IrDataType::Bool => {
171            target.copy_from_slice(&value.to_bytes_width(4)[..4]);
172        }
173        IrDataType::U64 => {
174            let bytes = value.to_bytes_width(8);
175            target.copy_from_slice(&bytes[..8]);
176        }
177        IrDataType::F32 => {
178            // Value::Float carries an f64; the GPU buffer is four bytes
179            // of f32, so narrow via `as f32` before writing. Dropping the
180            // upper four bytes of `v.to_le_bytes()` (what the default
181            // to_bytes_width path does) would mangle the f32 bit pattern.
182            let v = match value {
183                Value::Float(v) => *v as f32,
184                Value::U32(v) => f32::from_bits(*v),
185                _ => 0.0,
186            };
187            target.copy_from_slice(&v.to_le_bytes());
188        }
189        IrDataType::Bytes | IrDataType::Vec2U32 | IrDataType::Vec4U32 => {
190            let bytes = value.to_bytes_width(target.len());
191            let len = target.len().min(bytes.len());
192            target[..len].copy_from_slice(&bytes[..len]);
193            target[len..].fill(0);
194        }
195        _ => {
196            let bytes = value.to_bytes_width(target.len());
197            let len = target.len().min(bytes.len());
198            target[..len].copy_from_slice(&bytes[..len]);
199            target[len..].fill(0);
200        }
201    }
202}
203
/// Decode the first four bytes of `bytes` as a little-endian u32.
/// Panics if fewer than four bytes are provided (callers bounds-check first).
fn read_u32(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
207
/// Encode `value` as little-endian into `bytes`.
/// Panics unless `bytes` is exactly four bytes long (callers slice precisely).
fn write_u32(bytes: &mut [u8], value: u32) {
    let encoded = value.to_le_bytes();
    bytes.copy_from_slice(&encoded);
}
211
212fn ir_to_conform_type(ty: IrDataType) -> DataType {
213    match ty {
214        IrDataType::U32 => DataType::U32,
215        IrDataType::I32 => DataType::I32,
216        IrDataType::U64 => DataType::U64,
217        IrDataType::F32 => DataType::F32,
218        IrDataType::Vec2U32 => DataType::Vec2U32,
219        IrDataType::Vec4U32 => DataType::Vec4U32,
220        IrDataType::Bool => DataType::U32,
221        IrDataType::Bytes => DataType::Bytes,
222        _ => DataType::Bytes,
223    }
224}