tnj_pcc/encoding/
read.rs

1//! Module containing deserialization for proofs.
2
3use crate::encoding::iter::UntrustedSizeIterator;
4use crate::Proof;
5use arch::reg::Register;
6use smol_str::SmolStr;
7use std::collections::HashMap;
8use std::io;
9use std::io::Read;
10use sym::{BinOpTy, BoolOpTy, CmpTy, Expr, Term};
11use thiserror::Error;
12use types::Type;
13
14/// Read a proof from a reader.
15pub fn read<R>(r: &mut R) -> Result<Proof, Error>
16where
17    R: ?Sized + Read,
18{
19    let mut proof = Proof::default();
20
21    // first we read the strings
22    let n_strings = leb128::read::unsigned(r)?;
23    // we do not reserve space for the strings, as this would allow an attacker to easily crash our
24    // program, making it run out of memory.
25    let strings = UntrustedSizeIterator::new(n_strings)
26        .map(|i| Ok((i, read_str(r)?)))
27        .collect::<Result<_, Error>>()?;
28
29    let assertions_len = leb128::read::unsigned(r)?;
30
31    proof.assertions = UntrustedSizeIterator::new(assertions_len)
32        .map(|_| {
33            let addr = leb128::read::unsigned(r)?;
34            let num_exprs = leb128::read::unsigned(r)?;
35            let expressions = UntrustedSizeIterator::new(num_exprs)
36                .map(|_| read_expr(r, &strings))
37                .collect::<Result<Vec<_>, Error>>()?;
38            Ok((addr, expressions))
39        })
40        .collect::<Result<HashMap<_, _>, Error>>()?;
41
42    Ok(proof)
43}
44
45fn read_expr<R>(
46    r: &mut R,
47    strings: &HashMap<u64, SmolStr>,
48) -> Result<Expr, crate::encoding::read::Error>
49where
50    R: ?Sized + Read,
51{
52    let ty = read_ty(r)?;
53    let tag = leb128::read::unsigned(r)?;
54
55    let expr = match tag {
56        0 => {
57            let id = leb128::read::unsigned(r)?
58                .try_into()
59                .map_err(|_| Error::Overflow("expected u32 unknown id"))?;
60            Expr::Unknown(id, ty)
61        }
62        1 => {
63            let val = leb128::read::unsigned(r)?;
64            Expr::Const { val, ty }
65        }
66        2 => Expr::Load {
67            addr: Box::new(read_expr(r, strings)?),
68            ty,
69        },
70        3 => {
71            let op = match leb128::read::unsigned(r)? {
72                0 => BinOpTy::WrappingAdd,
73                1 => BinOpTy::WrappingSub,
74                n => return Err(Error::InvalidType(n, "BinOpTy")),
75            };
76            Expr::BinOp {
77                lhs: Box::new(read_expr(r, strings)?),
78                rhs: Box::new(read_expr(r, strings)?),
79                op,
80                ty,
81            }
82        }
83        4 => Expr::Term(read_term(r, strings)?),
84        5 => {
85            let cmp_ty = match leb128::read::unsigned(r)? {
86                0 => CmpTy::Equal,
87                1 => CmpTy::NotEqual,
88                2 => CmpTy::Less,
89                3 => CmpTy::LessEqual,
90                4 => CmpTy::Greater,
91                5 => CmpTy::GreaterEqual,
92                t => return Err(Error::InvalidType(t, "CmpTy")),
93            };
94            Expr::Cmp {
95                lhs: Box::new(read_expr(r, strings)?),
96                rhs: Box::new(read_expr(r, strings)?),
97                ty: cmp_ty,
98            }
99        }
100        6 => {
101            let op = match leb128::read::unsigned(r)? {
102                0 => BoolOpTy::And,
103                1 => BoolOpTy::Or,
104                2 => BoolOpTy::Imp,
105                3 => BoolOpTy::BiImp,
106                t => return Err(Error::InvalidType(t, "BoolOpTy")),
107            };
108            Expr::BoolOp {
109                lhs: Box::new(read_expr(r, strings)?),
110                rhs: Box::new(read_expr(r, strings)?),
111                ty: op,
112            }
113        }
114        7 => Expr::Not {
115            inner: Box::new(read_expr(r, strings)?),
116        },
117        8 => {
118            let string_id = leb128::read::unsigned(r)?;
119            let name = strings
120                .get(&string_id)
121                .ok_or(Error::InvalidStringId(string_id))?
122                .clone();
123            Expr::Reg(Register { name, ty })
124        }
125        t => return Err(Error::InvalidTag(t)),
126    };
127
128    Ok(expr)
129}
130
131fn read_term<R>(
132    r: &mut R,
133    strings: &HashMap<u64, SmolStr>,
134) -> Result<Term, crate::encoding::read::Error>
135where
136    R: ?Sized + Read,
137{
138    let string_id = leb128::read::unsigned(r)?;
139    let name = strings
140        .get(&string_id)
141        .ok_or(Error::InvalidStringId(string_id))?
142        .clone();
143    let ty = read_ty(r)?;
144    let len = leb128::read::unsigned(r)?;
145    let args = UntrustedSizeIterator::new(len)
146        .map(|_| read_expr(r, strings))
147        .collect::<Result<Vec<_>, Error>>()?;
148
149    Ok(Term { name, args, ty })
150}
151
152fn read_ty<R>(r: &mut R) -> Result<Type, Error>
153where
154    R: ?Sized + Read,
155{
156    let ty = leb128::read::unsigned(r)?;
157    match ty {
158        n if n <= 64 && n.is_power_of_two() => Ok(Type::int(n as u8)),
159        n => Err(Error::InvalidType(n, "type")),
160    }
161}
162
163fn read_str<R>(r: &mut R) -> Result<SmolStr, Error>
164where
165    R: ?Sized + Read,
166{
167    let mut len = leb128::read::unsigned(r)? as usize;
168
169    // we read the string in chunks of max 64 bytes
170    const CHUNK_SIZE: usize = 64;
171
172    // Since we don't trust the length, we cannot pre-allocate using the full length.
173    // We will pre-allocate using the chunk size instead
174    let mut bytes = Vec::with_capacity(usize::min(len, CHUNK_SIZE));
175
176    let mut buf = [0u8; CHUNK_SIZE];
177    while len > 0 {
178        let bytes_to_read = usize::min(len, CHUNK_SIZE);
179
180        let buf_slice = &mut buf[0..bytes_to_read];
181        r.read_exact(buf_slice)?;
182        bytes.extend_from_slice(buf_slice);
183
184        len -= bytes_to_read;
185    }
186
187    let s = std::str::from_utf8(&bytes)
188        .map_err(|_| Error::Utf8Error(String::from_utf8_lossy(&bytes).to_string()))?;
189
190    Ok(SmolStr::from(s))
191}
192
193/// Error type for reading proofs.
194#[derive(Debug, Error)]
195pub enum Error {
196    /// An unexpected end of input was encountered.
197    #[error("unexpected end of input")]
198    UnexpectedEof,
199    /// Wrapper around an leb128 error.
200    #[error("leb128 error: {0}")]
201    Leb128(#[from] leb128::read::Error),
202    /// Wrapper around an io error.
203    #[error("io error: {0}")]
204    Io(#[from] io::Error),
205    /// Invalid type
206    #[error("invalid {1}: {0}")]
207    InvalidType(u64, &'static str),
208    /// Invalid expression tag
209    #[error("invalid expression tag: {0}")]
210    InvalidTag(u64),
211    /// Overflow, value does not fit into type
212    #[error("overflow: {0}")]
213    Overflow(&'static str),
214    /// Invalid Utf-8
215    #[error("Invalid utf-8: {0}")]
216    Utf8Error(String),
217    /// Invalid string id
218    #[error("Invalid string id {0}, does not exist in strings section")]
219    InvalidStringId(u64),
220}
221
#[cfg(test)]
mod tests {
    use super::read;
    use arch::reg::Register;
    use sym::Expr;
    use types::I64;

    /// Two assertion entries, each holding one 64-bit constant.
    #[test]
    fn test_read() {
        let bytes: &[u8] = &[
            // strings section
            0x0, // no strings
            // assertions section
            0x2, // two entries
            // first entry
            0x0, // at offset 0
            0x1, // one assertion
            0x40, // type: I64 (0x40)
            0x1,  // tag: constant
            0x2a, // value 42
            // second entry
            0x1, // at offset 1
            0x1, // one assertion
            0x40, // type: I64 (0x40)
            0x1,  // tag: constant
            0x2b, // value 43
        ];

        let proof = read(&mut std::io::Cursor::new(bytes)).unwrap();

        assert_eq!(proof.assertions.len(), 2);
        assert_eq!(
            proof.assertions[&0],
            vec![Expr::Const { val: 42, ty: I64 }]
        );
        assert_eq!(
            proof.assertions[&1],
            vec![Expr::Const { val: 43, ty: I64 }]
        );
    }

    /// A register expression whose name is resolved through the string table.
    #[test]
    fn test_string() {
        let bytes: &[u8] = &[
            // strings section
            0x1, // one string
            // string with id 0
            0xc, // length 12
            // h  e     l     l     o     ,
            0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, //
            // ' ' w    o     r     l     d
            0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, //
            // assertions section
            0x1, // one entry
            0xc, // at offset 12
            0x1, // a single assertion
            0x40, // type: I64 (0x40)
            0x8,  // tag: register
            0x0,  // register name: string with id 0
        ];

        let proof = read(&mut std::io::Cursor::new(bytes)).unwrap();

        assert_eq!(proof.assertions.len(), 1);
        assert_eq!(
            proof.assertions[&12],
            vec![Expr::Reg(Register::new("hello, world", I64))]
        );
    }
}