Skip to main content

wolfram_serialize/wxf/
reader.rs

1//! Typed, pull-based WXF reader — sugar over a raw [`Reader`].
2//!
3//! Each WXF enum in [`crate::constants`] gets a reader that consumes its byte
4//! and does the `TryFrom` (failing if the byte isn't that enum). There is **no
5//! peek**: a token is read exactly once via [`WxfReader::read_expr_token`] and
6//! the caller dispatches on it, then reads the matching payload.
7//!
8//! Methods deal only in primitives and raw parts — higher-level value types
9//! (`Symbol`, `NumericArray`, …) are assembled by the consumer (`wolfram-expr`).
10
11use crate::constants::{ExpressionEnum, NumericArrayEnum, PackedArrayEnum};
12use crate::reader::Reader;
13use crate::Error;
14
/// Typed WXF reader wrapping a raw byte [`Reader`].
///
/// The wrapper owns the raw reader and forwards every read to it; it adds the
/// WXF-specific decoding (varints, enum tags, length-prefixed payloads) on top.
///
/// `Debug` is derived so the reader can be inspected in logs and assertion
/// messages whenever the wrapped reader itself is `Debug`.
#[derive(Debug)]
pub struct WxfReader<R> {
    /// Underlying raw byte source that all typed reads are forwarded to.
    inner: R,
}
19
20impl<'de, R: Reader<'de>> WxfReader<R> {
21    /// Wrap a raw reader. The reader is assumed to be positioned at the start of
22    /// the WXF payload (header already consumed — see [`crate::from_wxf`][fn@crate::from_wxf]).
23    pub fn new(inner: R) -> Self {
24        WxfReader { inner }
25    }
26
27    //---- raw passthrough ------------------------------------------------
28
29    /// Consume one raw byte.
30    pub fn read_byte(&mut self) -> Result<u8, Error> {
31        self.inner.read_byte()
32    }
33
34    /// Consume `n` raw bytes as a zero-copy, buffer-lifetime view.
35    pub fn read_bytes(&mut self, n: usize) -> Result<&'de [u8], Error> {
36        self.inner.read_bytes(n)
37    }
38
39    /// Read a WXF varint (LEB128, 7-bit groups, little-endian).
40    pub fn read_varint(&mut self) -> Result<u64, Error> {
41        let mut result: u64 = 0;
42        let mut shift: u32 = 0;
43        loop {
44            if shift >= 64 {
45                return Err(Error::invalid("varint exceeds 64 bits".into()));
46            }
47            let b = self.inner.read_byte()?;
48            // The 10th group sits at bit 63: only its low bit fits in a u64.
49            // Reject any higher bits (and trailing continuation) rather than
50            // silently truncating an overlong/non-canonical encoding.
51            if shift == 63 && b & !0x01 != 0 {
52                return Err(Error::invalid("varint exceeds 64 bits".into()));
53            }
54            result |= u64::from(b & 0x7F) << shift;
55            if b & 0x80 == 0 {
56                return Ok(result);
57            }
58            shift += 7;
59        }
60    }
61
62    //---- enum tags (consume one byte, TryFrom) --------------------------
63
64    /// Consume the next expression token byte.
65    pub fn read_expr_token(&mut self) -> Result<ExpressionEnum, Error> {
66        let b = self.inner.read_byte()?;
67        ExpressionEnum::try_from(b)
68            .map_err(|_| Error::invalid(format!("unknown WXF token byte 0x{:02X}", b)))
69    }
70
71    /// Consume a NumericArray element-type byte.
72    pub fn read_numeric_type(&mut self) -> Result<NumericArrayEnum, Error> {
73        let b = self.inner.read_byte()?;
74        NumericArrayEnum::try_from(b).map_err(|_| {
75            Error::invalid(format!("unknown NumericArray element type 0x{:02X}", b))
76        })
77    }
78
79    /// Consume a PackedArray element-type byte (numeric subset).
80    pub fn read_packed_type(&mut self) -> Result<PackedArrayEnum, Error> {
81        let b = self.inner.read_byte()?;
82        PackedArrayEnum::try_from(b).map_err(|_| {
83            Error::invalid(format!("unknown PackedArray element type 0x{:02X}", b))
84        })
85    }
86
87    //---- fixed-width integer / real payloads (tag already consumed) -----
88
89    /// Read an `Integer8` payload.
90    pub fn read_i8(&mut self) -> Result<i8, Error> {
91        Ok(self.inner.read_byte()? as i8)
92    }
93
94    /// Read an `Integer16` payload.
95    pub fn read_i16(&mut self) -> Result<i16, Error> {
96        let b = self.inner.read_bytes(2)?;
97        Ok(i16::from_le_bytes(b.try_into().unwrap()))
98    }
99
100    /// Read an `Integer32` payload.
101    pub fn read_i32(&mut self) -> Result<i32, Error> {
102        let b = self.inner.read_bytes(4)?;
103        Ok(i32::from_le_bytes(b.try_into().unwrap()))
104    }
105
106    /// Read an `Integer64` payload.
107    pub fn read_i64(&mut self) -> Result<i64, Error> {
108        let b = self.inner.read_bytes(8)?;
109        Ok(i64::from_le_bytes(b.try_into().unwrap()))
110    }
111
112    /// Read a `Real64` payload.
113    pub fn read_f64(&mut self) -> Result<f64, Error> {
114        let b = self.inner.read_bytes(8)?;
115        Ok(f64::from_le_bytes(b.try_into().unwrap()))
116    }
117
118    //---- length-prefixed payloads (tag already consumed) ----------------
119
120    /// Read a `String`/`Symbol`-shaped payload: varint length + UTF-8 bytes.
121    /// Zero-copy — returns a `&'de str` view into the underlying buffer, so it
122    /// serves both the owned path (`.to_owned()`) and borrowed fields (`&'de str`).
123    pub fn read_str(&mut self) -> Result<&'de str, Error> {
124        let len = self.read_varint()? as usize;
125        let bytes = self.inner.read_bytes(len)?;
126        std::str::from_utf8(bytes)
127            .map_err(|_| Error::invalid("payload not valid UTF-8".into()))
128    }
129
130    /// Read a complete `String` value (token + payload) into an owned `String`.
131    /// Used for keys/labels where the token has not been pre-consumed.
132    pub fn read_string(&mut self) -> Result<String, Error> {
133        match self.read_expr_token()? {
134            ExpressionEnum::String => Ok(self.read_str()?.to_owned()),
135            other => Err(Error::unexpected_token(&["String"], other)),
136        }
137    }
138
139    /// Read a `Symbol`/`BigInteger`/`BigReal` payload as an owned name/digit
140    /// string (`varint` length + UTF-8). The consumer parses it into the
141    /// appropriate value type.
142    pub fn read_symbol_name(&mut self) -> Result<String, Error> {
143        Ok(self.read_str()?.to_owned())
144    }
145
146    /// Read a `ByteArray` payload: varint length + raw bytes. Zero-copy — returns
147    /// a `&'de [u8]` view into the underlying buffer (owned path copies via
148    /// `.to_vec()`; borrowed `&'de [u8]` fields keep it).
149    pub fn read_byte_array(&mut self) -> Result<&'de [u8], Error> {
150        let len = self.read_varint()? as usize;
151        self.inner.read_bytes(len)
152    }
153
154    //---- arrays (tag already consumed) ----------------------------------
155
156    /// Read the body of a `NumericArray`/`PackedArray` token (tag already
157    /// consumed): element type + rank + dims + flat little-endian buffer.
158    /// Returns the element type, the dims, and the owned byte buffer.
159    pub fn read_numeric_array_parts(
160        &mut self,
161    ) -> Result<(NumericArrayEnum, Vec<usize>, Vec<u8>), Error> {
162        let dt = self.read_numeric_type()?;
163        let (dims, bytes) = self.read_array_body(dt.size_in_bytes())?;
164        Ok((dt, dims, bytes))
165    }
166
167    /// Read an array shape header: rank varint + `rank` dim varints. Returns the
168    /// dims and the **flat byte count** (`prod(dims) * elem_size`).
169    ///
170    /// Both quantities come from untrusted input, so: the dims vector caps its
171    /// pre-allocation (`capped_capacity`), and the byte count is computed with
172    /// overflow checking — a wrapping `prod(dims) * elem_size` would otherwise
173    /// yield a small count and silently read a truncated array instead of
174    /// erroring.
175    pub fn read_array_shape(
176        &mut self,
177        elem_size: usize,
178    ) -> Result<(Vec<usize>, usize), Error> {
179        let rank = self.read_varint()? as usize;
180        let mut dims = Vec::with_capacity(crate::capped_capacity(rank));
181        for _ in 0..rank {
182            dims.push(self.read_varint()? as usize);
183        }
184        let byte_count = dims
185            .iter()
186            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
187            .and_then(|count| count.checked_mul(elem_size))
188            .ok_or_else(|| Error::invalid("array byte count overflow".into()))?;
189        Ok((dims, byte_count))
190    }
191
192    /// Shared array tail: [`read_array_shape`][Self::read_array_shape] followed by
193    /// the flat little-endian byte buffer, returned as an owned `Vec<u8>`.
194    pub fn read_array_body(
195        &mut self,
196        elem_size: usize,
197    ) -> Result<(Vec<usize>, Vec<u8>), Error> {
198        let (dims, byte_count) = self.read_array_shape(elem_size)?;
199        let bytes = self.inner.read_bytes(byte_count)?.to_vec();
200        Ok((dims, bytes))
201    }
202
203    //---- association rules ----------------------------------------------
204
205    /// Read one `Rule` / `RuleDelayed` token; returns the `delayed` flag.
206    pub fn read_rule(&mut self) -> Result<bool, Error> {
207        match self.read_expr_token()? {
208            ExpressionEnum::Rule => Ok(false),
209            ExpressionEnum::RuleDelayed => Ok(true),
210            other => Err(Error::unexpected_token(&["Rule", "RuleDelayed"], other)),
211        }
212    }
213
214    //---- skip -----------------------------------------------------------
215
216    /// Read one complete value at the current position and discard it. Used to
217    /// drop an unknown Association key's value, or a Function head whose shape
218    /// isn't validated.
219    pub fn skip(&mut self) -> Result<(), Error> {
220        let tok = self.read_expr_token()?;
221        self.skip_body(tok)
222    }
223
224    fn skip_body(&mut self, tok: ExpressionEnum) -> Result<(), Error> {
225        match tok {
226            ExpressionEnum::Integer8 => {
227                self.read_i8()?;
228            },
229            ExpressionEnum::Integer16 => {
230                self.read_i16()?;
231            },
232            ExpressionEnum::Integer32 => {
233                self.read_i32()?;
234            },
235            ExpressionEnum::Integer64 => {
236                self.read_i64()?;
237            },
238            ExpressionEnum::Real64 => {
239                self.read_f64()?;
240            },
241            ExpressionEnum::String
242            | ExpressionEnum::Symbol
243            | ExpressionEnum::ByteArray
244            | ExpressionEnum::BigInteger
245            | ExpressionEnum::BigReal => {
246                let len = self.read_varint()? as usize;
247                self.inner.read_bytes(len)?;
248            },
249            ExpressionEnum::NumericArray | ExpressionEnum::PackedArray => {
250                // element-type byte (numeric subset shares wire bytes)
251                let dt = self.read_numeric_type()?;
252                let (_dims, byte_count) = self.read_array_shape(dt.size_in_bytes())?;
253                self.inner.read_bytes(byte_count)?;
254            },
255            ExpressionEnum::Function => {
256                let n = self.read_varint()?;
257                self.skip()?; // head
258                for _ in 0..n {
259                    self.skip()?;
260                }
261            },
262            ExpressionEnum::Association => {
263                let n = self.read_varint()?;
264                for _ in 0..n {
265                    self.read_rule()?;
266                    self.skip()?; // key
267                    self.skip()?; // value
268                }
269            },
270            // A Rule where a value was expected: "any token but this".
271            other @ (ExpressionEnum::Rule | ExpressionEnum::RuleDelayed) => {
272                return Err(Error::unexpected_token(&[], other))
273            },
274        }
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::SliceReader;
    use crate::wxf::writer::WxfWriter;

    /// Encode `n` with the writer's varint routine, then decode the resulting
    /// bytes with a fresh reader and return what came back.
    fn varint_roundtrip(n: u64) -> u64 {
        let mut writer = WxfWriter::new(Vec::new());
        writer.write_varint(n).unwrap();
        let encoded = writer.into_inner();
        let mut reader = WxfReader::new(SliceReader::new(&encoded));
        reader.read_varint().unwrap()
    }

    /// Values on both sides of each 7-bit group boundary, plus the extremes,
    /// must survive a write/read round trip unchanged.
    #[test]
    fn varint_roundtrips_over_full_range() {
        let samples = [0u64, 1, 127, 128, 16383, 16384, 1_000_000, u64::MAX];
        for &n in samples.iter() {
            assert_eq!(varint_roundtrip(n), n);
        }
    }

    /// Eleven bytes that all carry the continuation bit: the tenth group
    /// already cannot fit in 64 bits, so decoding must fail.
    #[test]
    fn varint_rejects_overlong_encoding() {
        let bytes = [0x80u8; 11];
        let outcome = WxfReader::new(SliceReader::new(&bytes)).read_varint();
        assert!(outcome.is_err());
    }

    /// Nine continuation groups bring the shift to 63; a final group with a
    /// bit above bit 63 set must be an error, not a silent truncation.
    #[test]
    fn varint_rejects_high_bits_in_final_group() {
        let bytes: Vec<u8> = std::iter::repeat(0x80u8)
            .take(9)
            .chain(std::iter::once(0x02))
            .collect();
        let outcome = WxfReader::new(SliceReader::new(&bytes)).read_varint();
        assert!(outcome.is_err());
    }
}