1use thiserror::Error;
9
/// Magic bytes that every WBIN image must start with (checked against bytes 0..4).
pub const WBIN_MAGIC: &[u8; 4] = b"WAVE";
/// The only header version the parser accepts; any other value is rejected.
pub const WBIN_VERSION: u16 = 0x0001;
/// Size in bytes of the fixed file header, and therefore the minimum length of an image.
pub const WBIN_HEADER_SIZE: usize = 32;
/// Size in bytes of one per-kernel record in the metadata section's kernel table.
pub const KERNEL_METADATA_SIZE: usize = 32;
14
/// Errors produced while parsing a WBIN image.
#[derive(Debug, Error)]
pub enum WbinError {
    /// The input is shorter than the fixed-size header.
    #[error("file too small: expected at least {expected} bytes, got {actual}")]
    FileTooSmall { expected: usize, actual: usize },

    /// The first four bytes are not the expected magic.
    #[error("invalid magic: expected 'WAVE', got {actual:?}")]
    InvalidMagic { actual: [u8; 4] },

    /// The header's version field holds a version the parser does not accept.
    #[error("unsupported version: {version}")]
    UnsupportedVersion { version: u16 },

    /// A section's start offset lies beyond the end of the input.
    #[error("invalid offset: {field} offset {offset} exceeds file size {file_size}")]
    InvalidOffset {
        /// Name of the offending section (the parser uses "code", "symbol", "metadata").
        field: &'static str,
        offset: u32,
        file_size: usize,
    },

    /// A section starts inside the input but `offset + size` runs past its end.
    #[error(
        "invalid size: {field} at offset {offset} with size {size} exceeds file size {file_size}"
    )]
    InvalidSize {
        /// Name of the offending section (the parser uses "code", "symbol", "metadata").
        field: &'static str,
        offset: u32,
        size: u32,
        file_size: usize,
    },

    /// No NUL terminator was found between a kernel name's offset and the end of the input.
    #[error("unterminated string in symbol table at offset {offset}")]
    UnterminatedString { offset: u32 },

    /// The kernel count at the start of the metadata section needs more bytes than the
    /// metadata section holds.
    #[error("invalid kernel count: {count} kernels but metadata size is {metadata_size}")]
    InvalidKernelCount { count: u32, metadata_size: u32 },
}
49
/// Decoded fixed-size file header. All offsets are absolute byte offsets from the start
/// of the image; all fields are stored little-endian.
#[derive(Debug, Clone)]
pub struct WbinHeader {
    pub version: u16,
    /// Raw flag bits; not interpreted by the parser in this file.
    pub flags: u16,
    /// Start of the code section.
    pub code_offset: u32,
    /// Length of the code section in bytes.
    pub code_size: u32,
    /// Start of the symbol table (0 together with `symbol_size == 0` means "absent").
    pub symbol_offset: u32,
    /// Length of the symbol table in bytes; 0 means the image has no symbols.
    pub symbol_size: u32,
    /// Start of the metadata section (kernel count followed by kernel records).
    pub metadata_offset: u32,
    /// Length of the metadata section in bytes.
    pub metadata_size: u32,
}
61
/// One decoded record from the metadata section's kernel table.
#[derive(Debug, Clone)]
pub struct KernelInfo {
    /// Kernel name read from the symbol table; empty when the record has no name offset
    /// or the image has no symbol table.
    pub name: String,
    /// Register count as stored in the record (not interpreted by the parser).
    pub register_count: u32,
    /// Local memory size as stored in the record; units are not defined here — TODO confirm.
    pub local_memory_size: u32,
    /// Workgroup dimensions in `[x, y, z]` order.
    pub workgroup_size: [u32; 3],
    /// Start of this kernel's code, relative to the start of the code section.
    pub code_offset: u32,
    /// Length of this kernel's code in bytes.
    pub code_size: u32,
}
71
/// A parsed WBIN image that borrows the original byte buffer instead of copying it.
#[derive(Debug, Clone)]
pub struct WbinFile<'a> {
    /// The complete image; section accessors slice into it.
    data: &'a [u8],
    /// Decoded header.
    pub header: WbinHeader,
    /// Kernel records decoded from the metadata section, in file order.
    pub kernels: Vec<KernelInfo>,
}
78
79impl<'a> WbinFile<'a> {
80 #[allow(clippy::missing_panics_doc)]
84 pub fn parse(data: &'a [u8]) -> Result<Self, WbinError> {
85 if data.len() < WBIN_HEADER_SIZE {
86 return Err(WbinError::FileTooSmall {
87 expected: WBIN_HEADER_SIZE,
88 actual: data.len(),
89 });
90 }
91
92 let magic: [u8; 4] = data[0..4].try_into().unwrap();
93 if &magic != WBIN_MAGIC {
94 return Err(WbinError::InvalidMagic { actual: magic });
95 }
96
97 let version = u16::from_le_bytes([data[4], data[5]]);
98 if version != WBIN_VERSION {
99 return Err(WbinError::UnsupportedVersion { version });
100 }
101
102 let flags = u16::from_le_bytes([data[6], data[7]]);
103 let code_offset = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
104 let code_size = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
105 let symbol_offset = u32::from_le_bytes([data[16], data[17], data[18], data[19]]);
106 let symbol_size = u32::from_le_bytes([data[20], data[21], data[22], data[23]]);
107 let metadata_offset = u32::from_le_bytes([data[24], data[25], data[26], data[27]]);
108 let metadata_size = u32::from_le_bytes([data[28], data[29], data[30], data[31]]);
109
110 Self::validate_section(data.len(), "code", code_offset, code_size)?;
111 if symbol_offset != 0 || symbol_size != 0 {
112 Self::validate_section(data.len(), "symbol", symbol_offset, symbol_size)?;
113 }
114 Self::validate_section(data.len(), "metadata", metadata_offset, metadata_size)?;
115
116 let header = WbinHeader {
117 version,
118 flags,
119 code_offset,
120 code_size,
121 symbol_offset,
122 symbol_size,
123 metadata_offset,
124 metadata_size,
125 };
126
127 let kernels = Self::parse_kernels(data, &header)?;
128
129 Ok(Self {
130 data,
131 header,
132 kernels,
133 })
134 }
135
136 fn validate_section(
137 file_size: usize,
138 field: &'static str,
139 offset: u32,
140 size: u32,
141 ) -> Result<(), WbinError> {
142 if offset as usize > file_size {
143 return Err(WbinError::InvalidOffset {
144 field,
145 offset,
146 file_size,
147 });
148 }
149 if (offset as usize + size as usize) > file_size {
150 return Err(WbinError::InvalidSize {
151 field,
152 offset,
153 size,
154 file_size,
155 });
156 }
157 Ok(())
158 }
159
160 fn parse_kernels(data: &[u8], header: &WbinHeader) -> Result<Vec<KernelInfo>, WbinError> {
161 if header.metadata_size < 4 {
162 return Ok(Vec::new());
163 }
164
165 let meta_start = header.metadata_offset as usize;
166 let kernel_count = u32::from_le_bytes([
167 data[meta_start],
168 data[meta_start + 1],
169 data[meta_start + 2],
170 data[meta_start + 3],
171 ]);
172
173 #[allow(clippy::cast_possible_truncation)]
174 let expected_size = 4 + kernel_count * (KERNEL_METADATA_SIZE as u32);
175 if header.metadata_size < expected_size {
176 return Err(WbinError::InvalidKernelCount {
177 count: kernel_count,
178 metadata_size: header.metadata_size,
179 });
180 }
181
182 let mut kernels = Vec::with_capacity(kernel_count as usize);
183 let mut offset = meta_start + 4;
184
185 for _ in 0..kernel_count {
186 let name_offset = u32::from_le_bytes([
187 data[offset],
188 data[offset + 1],
189 data[offset + 2],
190 data[offset + 3],
191 ]);
192 let register_count = u32::from_le_bytes([
193 data[offset + 4],
194 data[offset + 5],
195 data[offset + 6],
196 data[offset + 7],
197 ]);
198 let local_memory_size = u32::from_le_bytes([
199 data[offset + 8],
200 data[offset + 9],
201 data[offset + 10],
202 data[offset + 11],
203 ]);
204 let ws_x = u32::from_le_bytes([
205 data[offset + 12],
206 data[offset + 13],
207 data[offset + 14],
208 data[offset + 15],
209 ]);
210 let ws_y = u32::from_le_bytes([
211 data[offset + 16],
212 data[offset + 17],
213 data[offset + 18],
214 data[offset + 19],
215 ]);
216 let ws_z = u32::from_le_bytes([
217 data[offset + 20],
218 data[offset + 21],
219 data[offset + 22],
220 data[offset + 23],
221 ]);
222 let code_offset = u32::from_le_bytes([
223 data[offset + 24],
224 data[offset + 25],
225 data[offset + 26],
226 data[offset + 27],
227 ]);
228 let code_size = u32::from_le_bytes([
229 data[offset + 28],
230 data[offset + 29],
231 data[offset + 30],
232 data[offset + 31],
233 ]);
234
235 let name = if name_offset != 0 && header.symbol_size > 0 {
236 Self::read_string(data, name_offset as usize)?
237 } else {
238 String::new()
239 };
240
241 kernels.push(KernelInfo {
242 name,
243 register_count,
244 local_memory_size,
245 workgroup_size: [ws_x, ws_y, ws_z],
246 code_offset,
247 code_size,
248 });
249
250 offset += KERNEL_METADATA_SIZE;
251 }
252
253 Ok(kernels)
254 }
255
256 fn read_string(data: &[u8], offset: usize) -> Result<String, WbinError> {
257 let mut end = offset;
258 while end < data.len() && data[end] != 0 {
259 end += 1;
260 }
261 if end >= data.len() {
262 #[allow(clippy::cast_possible_truncation)]
263 return Err(WbinError::UnterminatedString {
264 offset: offset as u32,
265 });
266 }
267 Ok(String::from_utf8_lossy(&data[offset..end]).to_string())
268 }
269
270 #[must_use]
271 pub fn code(&self) -> &[u8] {
272 let start = self.header.code_offset as usize;
273 let end = start + self.header.code_size as usize;
274 &self.data[start..end]
275 }
276
277 #[must_use]
278 pub fn kernel_code(&self, kernel_index: usize) -> Option<&[u8]> {
279 let kernel = self.kernels.get(kernel_index)?;
280 let start = self.header.code_offset as usize + kernel.code_offset as usize;
281 let end = start + kernel.code_size as usize;
282 if end <= self.data.len() {
283 Some(&self.data[start..end])
284 } else {
285 None
286 }
287 }
288
289 #[must_use]
290 pub fn find_kernel(&self, name: &str) -> Option<&KernelInfo> {
291 self.kernels.iter().find(|k| k.name == name)
292 }
293
294 #[must_use]
295 pub fn has_symbols(&self) -> bool {
296 self.header.symbol_size > 0
297 }
298
299 #[must_use]
300 pub fn kernel_count(&self) -> usize {
301 self.kernels.len()
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 40-byte image: a 4-byte zeroed code section at 0x20, no symbol table,
    /// and a 4-byte metadata section at 0x24 whose kernel count is zero.
    fn make_minimal_wbin() -> Vec<u8> {
        let mut data = Vec::new();
        // Header: magic, version, flags.
        data.extend_from_slice(b"WAVE");
        data.extend_from_slice(&WBIN_VERSION.to_le_bytes());
        data.extend_from_slice(&0u16.to_le_bytes());
        // Code section: offset 0x20, size 4.
        data.extend_from_slice(&0x20u32.to_le_bytes());
        data.extend_from_slice(&4u32.to_le_bytes());
        // Symbol table: offset 0, size 0 (absent).
        data.extend_from_slice(&0u32.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        // Metadata section: offset 0x24, size 4.
        data.extend_from_slice(&0x24u32.to_le_bytes());
        data.extend_from_slice(&4u32.to_le_bytes());
        // 0x20: the four code bytes.
        data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        // 0x24: kernel count = 0.
        data.extend_from_slice(&0u32.to_le_bytes());
        data
    }

    /// Builds an image with 8 code bytes at 0x20, a symbol table holding
    /// `test_kernel\0` at 0x28, and a metadata section at 0x34 containing one kernel
    /// record (16 registers, 1024 local memory, workgroup 256x1x1, code 0..8 of the code
    /// section).
    fn make_wbin_with_kernel() -> Vec<u8> {
        let mut data = Vec::new();
        let kernel_name = b"test_kernel\0";
        let code_bytes: [u8; 8] = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        // Sections are laid out back to back after the 32-byte header.
        let code_offset: u32 = 0x20;
        let code_size: u32 = 8;
        let symbol_offset: u32 = code_offset + code_size;
        let symbol_size: u32 = kernel_name.len() as u32;
        let metadata_offset: u32 = symbol_offset + symbol_size;
        // Kernel count plus one KERNEL_METADATA_SIZE-byte record.
        let metadata_size: u32 = 4 + 32;

        // Header.
        data.extend_from_slice(b"WAVE");
        data.extend_from_slice(&WBIN_VERSION.to_le_bytes());
        data.extend_from_slice(&0u16.to_le_bytes());
        data.extend_from_slice(&code_offset.to_le_bytes());
        data.extend_from_slice(&code_size.to_le_bytes());
        data.extend_from_slice(&symbol_offset.to_le_bytes());
        data.extend_from_slice(&symbol_size.to_le_bytes());
        data.extend_from_slice(&metadata_offset.to_le_bytes());
        data.extend_from_slice(&metadata_size.to_le_bytes());
        // Code and symbol sections.
        data.extend_from_slice(&code_bytes);
        data.extend_from_slice(kernel_name);
        // Metadata: kernel count, then one record.
        data.extend_from_slice(&1u32.to_le_bytes());
        // Name offset: absolute offset of the name inside the symbol table.
        data.extend_from_slice(&symbol_offset.to_le_bytes());
        // Register count, local memory size.
        data.extend_from_slice(&16u32.to_le_bytes());
        data.extend_from_slice(&1024u32.to_le_bytes());
        // Workgroup size x, y, z.
        data.extend_from_slice(&256u32.to_le_bytes());
        data.extend_from_slice(&1u32.to_le_bytes());
        data.extend_from_slice(&1u32.to_le_bytes());
        // Code offset (relative to the code section) and code size.
        data.extend_from_slice(&0u32.to_le_bytes());
        data.extend_from_slice(&8u32.to_le_bytes());
        data
    }

    #[test]
    fn test_parse_minimal() {
        let data = make_minimal_wbin();
        let wbin = WbinFile::parse(&data).unwrap();

        assert_eq!(wbin.header.version, WBIN_VERSION);
        assert_eq!(wbin.header.code_size, 4);
        assert_eq!(wbin.kernels.len(), 0);
    }

    #[test]
    fn test_parse_with_kernel() {
        let data = make_wbin_with_kernel();
        let wbin = WbinFile::parse(&data).unwrap();

        assert_eq!(wbin.kernels.len(), 1);
        assert_eq!(wbin.kernels[0].name, "test_kernel");
        assert_eq!(wbin.kernels[0].register_count, 16);
        assert_eq!(wbin.kernels[0].local_memory_size, 1024);
        assert_eq!(wbin.kernels[0].workgroup_size, [256, 1, 1]);
    }

    #[test]
    fn test_invalid_magic() {
        let mut data = make_minimal_wbin();
        // Corrupt the first magic byte.
        data[0] = b'X';

        let err = WbinFile::parse(&data).unwrap_err();
        assert!(matches!(err, WbinError::InvalidMagic { .. }));
    }

    #[test]
    fn test_file_too_small() {
        // Shorter than the 32-byte header.
        let data = vec![0u8; 16];
        let err = WbinFile::parse(&data).unwrap_err();
        assert!(matches!(err, WbinError::FileTooSmall { .. }));
    }

    #[test]
    fn test_get_code() {
        let data = make_wbin_with_kernel();
        let wbin = WbinFile::parse(&data).unwrap();

        let code = wbin.code();
        assert_eq!(code.len(), 8);
        assert_eq!(code, &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
    }

    #[test]
    fn test_kernel_code() {
        let data = make_wbin_with_kernel();
        let wbin = WbinFile::parse(&data).unwrap();

        let code = wbin.kernel_code(0).unwrap();
        assert_eq!(code.len(), 8);
    }

    #[test]
    fn test_find_kernel() {
        let data = make_wbin_with_kernel();
        let wbin = WbinFile::parse(&data).unwrap();

        let kernel = wbin.find_kernel("test_kernel").unwrap();
        assert_eq!(kernel.register_count, 16);

        assert!(wbin.find_kernel("nonexistent").is_none());
    }

    #[test]
    fn test_has_symbols() {
        let data_with_symbols = make_wbin_with_kernel();
        let wbin_with = WbinFile::parse(&data_with_symbols).unwrap();
        assert!(wbin_with.has_symbols());

        // The minimal image declares a zero-sized symbol table.
        let data_stripped = make_minimal_wbin();
        let wbin_stripped = WbinFile::parse(&data_stripped).unwrap();
        assert!(!wbin_stripped.has_symbols());
    }
}