Skip to main content

wave_decode/
wbin.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! WBIN container reader.
//!
//! Parses the .wbin binary format produced by wave-asm, extracting header
//! information, code sections, symbols, and kernel metadata.
7
8use thiserror::Error;
9
/// Magic bytes that must appear at offset 0 of every `.wbin` image.
pub const WBIN_MAGIC: &[u8; 4] = b"WAVE";
/// The only container version accepted by the parser (little-endian `u16` at
/// header bytes 4..6).
pub const WBIN_VERSION: u16 = 0x0001;
/// Size in bytes of the fixed header; inputs shorter than this are rejected.
pub const WBIN_HEADER_SIZE: usize = 32;
/// Size in bytes of one kernel entry in the metadata section.
pub const KERNEL_METADATA_SIZE: usize = 32;
14
/// Errors reported while validating and decoding a `.wbin` image.
#[derive(Debug, Error)]
pub enum WbinError {
    /// The input is shorter than the fixed-size header (`expected` bytes).
    #[error("file too small: expected at least {expected} bytes, got {actual}")]
    FileTooSmall { expected: usize, actual: usize },

    /// The first four bytes of the input are not the `WAVE` magic.
    #[error("invalid magic: expected 'WAVE', got {actual:?}")]
    InvalidMagic { actual: [u8; 4] },

    /// The header's version field differs from the supported version.
    #[error("unsupported version: {version}")]
    UnsupportedVersion { version: u16 },

    /// The start offset of the section named by `field` lies past the end of
    /// the input.
    #[error("invalid offset: {field} offset {offset} exceeds file size {file_size}")]
    InvalidOffset {
        field: &'static str,
        offset: u32,
        file_size: usize,
    },

    /// The section named by `field` starts inside the input but `offset + size`
    /// runs past its end.
    #[error(
        "invalid size: {field} at offset {offset} with size {size} exceeds file size {file_size}"
    )]
    InvalidSize {
        field: &'static str,
        offset: u32,
        size: u32,
        file_size: usize,
    },

    /// No NUL terminator was found between `offset` and the end of the input
    /// while reading a kernel name.
    #[error("unterminated string in symbol table at offset {offset}")]
    UnterminatedString { offset: u32 },

    /// The metadata section is too small to hold the kernel count word plus
    /// `count` kernel entries.
    #[error("invalid kernel count: {count} kernels but metadata size is {metadata_size}")]
    InvalidKernelCount { count: u32, metadata_size: u32 },
}
49
/// Decoded fixed-size header of a `.wbin` image.
///
/// All fields are stored little-endian in the file, directly after the
/// four-byte magic. Offsets are measured in bytes from the start of the file.
#[derive(Debug, Clone)]
pub struct WbinHeader {
    /// Container format version (header bytes 4..6).
    pub version: u16,
    /// Flag bits (header bytes 6..8); stored as read and not interpreted here.
    pub flags: u16,
    /// Start of the code section (header bytes 8..12).
    pub code_offset: u32,
    /// Length of the code section in bytes (header bytes 12..16).
    pub code_size: u32,
    /// Start of the symbol section (header bytes 16..20); zero together with
    /// `symbol_size` when the image carries no symbols.
    pub symbol_offset: u32,
    /// Length of the symbol section in bytes (header bytes 20..24).
    pub symbol_size: u32,
    /// Start of the kernel metadata section (header bytes 24..28).
    pub metadata_offset: u32,
    /// Length of the kernel metadata section in bytes (header bytes 28..32).
    pub metadata_size: u32,
}
61
/// One kernel entry decoded from the metadata section.
///
/// Each entry occupies `KERNEL_METADATA_SIZE` bytes and consists of eight
/// little-endian 32-bit words; the first word is the name offset and the
/// remaining seven populate the numeric fields below in order.
#[derive(Debug, Clone)]
pub struct KernelInfo {
    /// Kernel name, read as a NUL-terminated string at the entry's name offset
    /// (a file offset). Empty when the name offset is zero or the image has no
    /// symbol section. Invalid UTF-8 is replaced lossily.
    pub name: String,
    /// Second word of the entry. Presumably the number of registers the kernel
    /// uses; the value is not interpreted by this module.
    pub register_count: u32,
    /// Third word of the entry. Presumably the kernel's local memory
    /// requirement; units are not established here (TODO confirm).
    pub local_memory_size: u32,
    /// Fourth to sixth words of the entry, stored as `[x, y, z]`.
    pub workgroup_size: [u32; 3],
    /// Start of the kernel's code, relative to the header's `code_offset`.
    pub code_offset: u32,
    /// Length of the kernel's code in bytes.
    pub code_size: u32,
}
71
/// A parsed `.wbin` image that borrows the caller's input buffer.
#[derive(Debug, Clone)]
pub struct WbinFile<'a> {
    // The complete input buffer; code slices are handed out as views into it.
    data: &'a [u8],
    /// Decoded file header.
    pub header: WbinHeader,
    /// Kernel entries in the order they appear in the metadata section.
    pub kernels: Vec<KernelInfo>,
}
78
79impl<'a> WbinFile<'a> {
80    /// # Errors
81    ///
82    /// Returns `WbinError` if the binary is malformed or too small.
83    #[allow(clippy::missing_panics_doc)]
84    pub fn parse(data: &'a [u8]) -> Result<Self, WbinError> {
85        if data.len() < WBIN_HEADER_SIZE {
86            return Err(WbinError::FileTooSmall {
87                expected: WBIN_HEADER_SIZE,
88                actual: data.len(),
89            });
90        }
91
92        let magic: [u8; 4] = data[0..4].try_into().unwrap();
93        if &magic != WBIN_MAGIC {
94            return Err(WbinError::InvalidMagic { actual: magic });
95        }
96
97        let version = u16::from_le_bytes([data[4], data[5]]);
98        if version != WBIN_VERSION {
99            return Err(WbinError::UnsupportedVersion { version });
100        }
101
102        let flags = u16::from_le_bytes([data[6], data[7]]);
103        let code_offset = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
104        let code_size = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
105        let symbol_offset = u32::from_le_bytes([data[16], data[17], data[18], data[19]]);
106        let symbol_size = u32::from_le_bytes([data[20], data[21], data[22], data[23]]);
107        let metadata_offset = u32::from_le_bytes([data[24], data[25], data[26], data[27]]);
108        let metadata_size = u32::from_le_bytes([data[28], data[29], data[30], data[31]]);
109
110        Self::validate_section(data.len(), "code", code_offset, code_size)?;
111        if symbol_offset != 0 || symbol_size != 0 {
112            Self::validate_section(data.len(), "symbol", symbol_offset, symbol_size)?;
113        }
114        Self::validate_section(data.len(), "metadata", metadata_offset, metadata_size)?;
115
116        let header = WbinHeader {
117            version,
118            flags,
119            code_offset,
120            code_size,
121            symbol_offset,
122            symbol_size,
123            metadata_offset,
124            metadata_size,
125        };
126
127        let kernels = Self::parse_kernels(data, &header)?;
128
129        Ok(Self {
130            data,
131            header,
132            kernels,
133        })
134    }
135
136    fn validate_section(
137        file_size: usize,
138        field: &'static str,
139        offset: u32,
140        size: u32,
141    ) -> Result<(), WbinError> {
142        if offset as usize > file_size {
143            return Err(WbinError::InvalidOffset {
144                field,
145                offset,
146                file_size,
147            });
148        }
149        if (offset as usize + size as usize) > file_size {
150            return Err(WbinError::InvalidSize {
151                field,
152                offset,
153                size,
154                file_size,
155            });
156        }
157        Ok(())
158    }
159
160    fn parse_kernels(data: &[u8], header: &WbinHeader) -> Result<Vec<KernelInfo>, WbinError> {
161        if header.metadata_size < 4 {
162            return Ok(Vec::new());
163        }
164
165        let meta_start = header.metadata_offset as usize;
166        let kernel_count = u32::from_le_bytes([
167            data[meta_start],
168            data[meta_start + 1],
169            data[meta_start + 2],
170            data[meta_start + 3],
171        ]);
172
173        #[allow(clippy::cast_possible_truncation)]
174        let expected_size = 4 + kernel_count * (KERNEL_METADATA_SIZE as u32);
175        if header.metadata_size < expected_size {
176            return Err(WbinError::InvalidKernelCount {
177                count: kernel_count,
178                metadata_size: header.metadata_size,
179            });
180        }
181
182        let mut kernels = Vec::with_capacity(kernel_count as usize);
183        let mut offset = meta_start + 4;
184
185        for _ in 0..kernel_count {
186            let name_offset = u32::from_le_bytes([
187                data[offset],
188                data[offset + 1],
189                data[offset + 2],
190                data[offset + 3],
191            ]);
192            let register_count = u32::from_le_bytes([
193                data[offset + 4],
194                data[offset + 5],
195                data[offset + 6],
196                data[offset + 7],
197            ]);
198            let local_memory_size = u32::from_le_bytes([
199                data[offset + 8],
200                data[offset + 9],
201                data[offset + 10],
202                data[offset + 11],
203            ]);
204            let ws_x = u32::from_le_bytes([
205                data[offset + 12],
206                data[offset + 13],
207                data[offset + 14],
208                data[offset + 15],
209            ]);
210            let ws_y = u32::from_le_bytes([
211                data[offset + 16],
212                data[offset + 17],
213                data[offset + 18],
214                data[offset + 19],
215            ]);
216            let ws_z = u32::from_le_bytes([
217                data[offset + 20],
218                data[offset + 21],
219                data[offset + 22],
220                data[offset + 23],
221            ]);
222            let code_offset = u32::from_le_bytes([
223                data[offset + 24],
224                data[offset + 25],
225                data[offset + 26],
226                data[offset + 27],
227            ]);
228            let code_size = u32::from_le_bytes([
229                data[offset + 28],
230                data[offset + 29],
231                data[offset + 30],
232                data[offset + 31],
233            ]);
234
235            let name = if name_offset != 0 && header.symbol_size > 0 {
236                Self::read_string(data, name_offset as usize)?
237            } else {
238                String::new()
239            };
240
241            kernels.push(KernelInfo {
242                name,
243                register_count,
244                local_memory_size,
245                workgroup_size: [ws_x, ws_y, ws_z],
246                code_offset,
247                code_size,
248            });
249
250            offset += KERNEL_METADATA_SIZE;
251        }
252
253        Ok(kernels)
254    }
255
256    fn read_string(data: &[u8], offset: usize) -> Result<String, WbinError> {
257        let mut end = offset;
258        while end < data.len() && data[end] != 0 {
259            end += 1;
260        }
261        if end >= data.len() {
262            #[allow(clippy::cast_possible_truncation)]
263            return Err(WbinError::UnterminatedString {
264                offset: offset as u32,
265            });
266        }
267        Ok(String::from_utf8_lossy(&data[offset..end]).to_string())
268    }
269
270    #[must_use]
271    pub fn code(&self) -> &[u8] {
272        let start = self.header.code_offset as usize;
273        let end = start + self.header.code_size as usize;
274        &self.data[start..end]
275    }
276
277    #[must_use]
278    pub fn kernel_code(&self, kernel_index: usize) -> Option<&[u8]> {
279        let kernel = self.kernels.get(kernel_index)?;
280        let start = self.header.code_offset as usize + kernel.code_offset as usize;
281        let end = start + kernel.code_size as usize;
282        if end <= self.data.len() {
283            Some(&self.data[start..end])
284        } else {
285            None
286        }
287    }
288
289    #[must_use]
290    pub fn find_kernel(&self, name: &str) -> Option<&KernelInfo> {
291        self.kernels.iter().find(|k| k.name == name)
292    }
293
294    #[must_use]
295    pub fn has_symbols(&self) -> bool {
296        self.header.symbol_size > 0
297    }
298
299    #[must_use]
300    pub fn kernel_count(&self) -> usize {
301        self.kernels.len()
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends each value to `buf` as four little-endian bytes.
    fn push_words(buf: &mut Vec<u8>, words: &[u32]) {
        for word in words {
            buf.extend_from_slice(&word.to_le_bytes());
        }
    }

    /// Builds a 40-byte image with no symbols and no kernels: the header, four
    /// zero code bytes at 0x20, and a metadata section at 0x24 that holds only
    /// a kernel count of zero.
    fn make_minimal_wbin() -> Vec<u8> {
        let mut image = b"WAVE".to_vec();
        image.extend_from_slice(&WBIN_VERSION.to_le_bytes());
        image.extend_from_slice(&0u16.to_le_bytes());
        // code offset/size, symbol offset/size, metadata offset/size
        push_words(&mut image, &[0x20, 4, 0, 0, 0x24, 4]);
        // code section
        image.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        // metadata section: kernel count
        push_words(&mut image, &[0]);
        image
    }

    /// Builds an image with eight code bytes, a symbol section holding
    /// `"test_kernel\0"`, and one kernel entry that references both.
    fn make_wbin_with_kernel() -> Vec<u8> {
        let kernel_name = b"test_kernel\0";
        let code_bytes: [u8; 8] = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        // Sections are laid out back to back right after the header.
        let code_offset: u32 = 0x20;
        let code_size: u32 = 8;
        let symbol_offset: u32 = code_offset + code_size;
        let symbol_size: u32 = kernel_name.len() as u32;
        let metadata_offset: u32 = symbol_offset + symbol_size;
        let metadata_size: u32 = 4 + 32;

        let mut image = b"WAVE".to_vec();
        image.extend_from_slice(&WBIN_VERSION.to_le_bytes());
        image.extend_from_slice(&0u16.to_le_bytes());
        push_words(
            &mut image,
            &[
                code_offset,
                code_size,
                symbol_offset,
                symbol_size,
                metadata_offset,
                metadata_size,
            ],
        );

        image.extend_from_slice(&code_bytes);
        image.extend_from_slice(kernel_name);

        // Kernel count, then one entry: name offset, register count, local
        // memory size, workgroup x/y/z, code offset, code size.
        push_words(&mut image, &[1]);
        push_words(&mut image, &[symbol_offset, 16, 1024, 256, 1, 1, 0, 8]);
        image
    }

    #[test]
    fn test_parse_minimal() {
        let bytes = make_minimal_wbin();
        let parsed = WbinFile::parse(&bytes).unwrap();

        assert_eq!(parsed.header.version, WBIN_VERSION);
        assert_eq!(parsed.header.code_size, 4);
        assert_eq!(parsed.kernels.len(), 0);
    }

    #[test]
    fn test_parse_with_kernel() {
        let bytes = make_wbin_with_kernel();
        let parsed = WbinFile::parse(&bytes).unwrap();

        assert_eq!(parsed.kernels.len(), 1);
        let kernel = &parsed.kernels[0];
        assert_eq!(kernel.name, "test_kernel");
        assert_eq!(kernel.register_count, 16);
        assert_eq!(kernel.local_memory_size, 1024);
        assert_eq!(kernel.workgroup_size, [256, 1, 1]);
    }

    #[test]
    fn test_invalid_magic() {
        let mut bytes = make_minimal_wbin();
        bytes[0] = b'X';

        let failure = WbinFile::parse(&bytes).unwrap_err();
        assert!(matches!(failure, WbinError::InvalidMagic { .. }));
    }

    #[test]
    fn test_file_too_small() {
        let bytes = vec![0u8; 16];
        let failure = WbinFile::parse(&bytes).unwrap_err();
        assert!(matches!(failure, WbinError::FileTooSmall { .. }));
    }

    #[test]
    fn test_get_code() {
        let bytes = make_wbin_with_kernel();
        let parsed = WbinFile::parse(&bytes).unwrap();

        let section = parsed.code();
        assert_eq!(section.len(), 8);
        assert_eq!(section, &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
    }

    #[test]
    fn test_kernel_code() {
        let bytes = make_wbin_with_kernel();
        let parsed = WbinFile::parse(&bytes).unwrap();

        let section = parsed.kernel_code(0).unwrap();
        assert_eq!(section.len(), 8);
    }

    #[test]
    fn test_find_kernel() {
        let bytes = make_wbin_with_kernel();
        let parsed = WbinFile::parse(&bytes).unwrap();

        let found = parsed.find_kernel("test_kernel").unwrap();
        assert_eq!(found.register_count, 16);

        assert!(parsed.find_kernel("nonexistent").is_none());
    }

    #[test]
    fn test_has_symbols() {
        let with_symbols = make_wbin_with_kernel();
        assert!(WbinFile::parse(&with_symbols).unwrap().has_symbols());

        let stripped = make_minimal_wbin();
        assert!(!WbinFile::parse(&stripped).unwrap().has_symbols());
    }
}