wtns_file/
lib.rs

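//! Parser and serializer for the binary `.wtns` witness file format.
//!
//! A file consists of the `wtns` magic bytes, a version, a section count, a header
//! section (type 1) and a witness section (type 2). The layout follows the snarkjs
//! reference implementation: <https://github.com/iden3/snarkjs/blob/master/src/wtns_utils.js>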
3
4use std::io::{Error, ErrorKind, Read, Result, Write};
5
6use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
7
/// Four-byte signature (ASCII `wtns`) that every witness file starts with.
const MAGIC: &[u8; 4] = &[b'w', b't', b'n', b's'];
9
10#[derive(Debug, PartialEq)]
11pub struct WtnsFile<const FS: usize> {
12    pub version: u32,
13    pub header: Header<FS>,
14    pub witness: Witness<FS>,
15}
16
17impl<const FS: usize> WtnsFile<FS> {
18    pub fn from_vec(witness: Vec<FieldElement<FS>>, prime: FieldElement<FS>) -> Self {
19        WtnsFile {
20            version: 1,
21            header: Header {
22                field_size: FS as u32,
23                prime,
24                witness_len: witness.len() as u32,
25            },
26            witness: Witness(witness),
27        }
28    }
29
30    pub fn read<R: Read>(mut r: R) -> Result<Self> {
31        let mut magic = [0u8; 4];
32        r.read_exact(&mut magic)?;
33
34        if magic != *MAGIC {
35            return Err(Error::new(ErrorKind::InvalidData, "Invalid magic number"));
36        }
37
38        let version = r.read_u32::<LittleEndian>()?;
39        if version > 2 {
40            return Err(Error::new(ErrorKind::InvalidData, "Unsupported version"));
41        }
42
43        let num_sections = r.read_u32::<LittleEndian>()?;
44        if num_sections > 2 {
45            return Err(Error::new(
46                ErrorKind::InvalidData,
47                "Number of sections >2 is not supported",
48            ));
49        }
50
51        let header = Header::read(&mut r)?;
52        let witness = Witness::read(&mut r, &header)?;
53
54        Ok(WtnsFile {
55            version,
56            header,
57            witness,
58        })
59    }
60
61    pub fn write<W: Write>(&self, mut w: W) -> Result<()> {
62        w.write_all(MAGIC)?;
63        w.write_u32::<LittleEndian>(self.version)?;
64        w.write_u32::<LittleEndian>(2)?;
65        self.header.write(&mut w)?;
66        self.witness.write(&mut w)?;
67
68        Ok(())
69    }
70}
71
72#[derive(Debug, PartialEq)]
73pub struct Header<const FS: usize> {
74    pub field_size: u32,
75    pub prime: FieldElement<FS>,
76    pub witness_len: u32,
77}
78
79impl<const FS: usize> Header<FS> {
80    pub fn read<R: Read>(mut r: R) -> Result<Self> {
81        let sec_type = SectionType::read(&mut r)?;
82        if sec_type != SectionType::Header {
83            return Err(Error::new(
84                ErrorKind::InvalidData,
85                "Invalid section type: expected header",
86            ));
87        }
88
89        let sec_size = r.read_u64::<LittleEndian>()?;
90        if sec_size != 4 + FS as u64 + 4 {
91            return Err(Error::new(
92                ErrorKind::InvalidData,
93                "Invalid header section size",
94            ));
95        }
96
97        let field_size = r.read_u32::<LittleEndian>()?;
98        let prime = FieldElement::read(&mut r)?;
99
100        if field_size != FS as u32 {
101            return Err(Error::new(ErrorKind::InvalidData, "Wrong field size"));
102        }
103
104        let witness_len = r.read_u32::<LittleEndian>()?;
105
106        Ok(Header {
107            field_size,
108            prime,
109            witness_len,
110        })
111    }
112
113    pub fn write<W: Write>(&self, mut w: W) -> Result<()> {
114        SectionType::Header.write(&mut w)?;
115
116        let sec_size = 4 + FS as u64 + 4;
117        w.write_u64::<LittleEndian>(sec_size)?;
118
119        w.write_u32::<LittleEndian>(FS as u32)?;
120        self.prime.write(&mut w)?;
121        w.write_u32::<LittleEndian>(self.witness_len)?;
122
123        Ok(())
124    }
125}
126
127#[derive(Debug, PartialEq)]
128pub struct Witness<const FS: usize>(pub Vec<FieldElement<FS>>);
129
130impl<const FS: usize> Witness<FS> {
131    pub fn read<R: Read>(mut r: R, header: &Header<FS>) -> Result<Self> {
132        let sec_type = SectionType::read(&mut r)?;
133        if sec_type != SectionType::Witness {
134            return Err(Error::new(ErrorKind::InvalidData, "Invalid section type: expected witness"));
135        }
136        let sec_size = r.read_u64::<LittleEndian>()?;
137
138        if sec_size != header.witness_len as u64 * FS as u64 {
139            return Err(Error::new(
140                ErrorKind::InvalidData,
141                "Invalid witness section size",
142            ));
143        }
144
145        let mut witness = Vec::with_capacity(header.witness_len as usize);
146        for _ in 0..header.witness_len {
147            witness.push(FieldElement::read(&mut r)?);
148        }
149
150        Ok(Witness(witness))
151    }
152
153    fn write<W: Write>(&self, mut w: W) -> Result<()> {
154        SectionType::Witness.write(&mut w)?;
155
156        let sec_size = (self.0.len() * FS) as u64;
157        w.write_u64::<LittleEndian>(sec_size)?;
158
159        for e in &self.0 {
160            e.write(&mut w)?;
161        }
162
163        Ok(())
164    }
165}
166
/// Section identifiers of a `.wtns` file, stored on disk as little-endian `u32` tags.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
#[repr(u32)]
pub enum SectionType {
    /// Tag 1: header section.
    Header = 1,
    /// Tag 2: witness section.
    Witness = 2,
    /// Any other tag; written back as `u32::MAX`.
    Unknown = u32::MAX,
}

impl SectionType {
    /// Reads a 4-byte little-endian tag from `r`.
    ///
    /// Unrecognized tags map to `Unknown` instead of producing an error; only a
    /// failed read is an error.
    fn read<R: Read>(mut r: R) -> Result<Self> {
        let mut tag = [0u8; 4];
        r.read_exact(&mut tag)?;

        Ok(match u32::from_le_bytes(tag) {
            1 => Self::Header,
            2 => Self::Witness,
            _ => Self::Unknown,
        })
    }

    /// Writes this variant's discriminant as a 4-byte little-endian tag.
    fn write<W: Write>(&self, mut w: W) -> Result<()> {
        let tag = *self as u32;
        w.write_all(&tag.to_le_bytes())
    }
}
194
/// A single field element held as exactly `FS` raw bytes, copied verbatim from or
/// to the file (no numeric interpretation happens here).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FieldElement<const FS: usize>([u8; FS]);

impl<const FS: usize> FieldElement<FS> {
    /// Returns the raw bytes of this element as a slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0[..]
    }

    /// Reads exactly `FS` bytes from `r`.
    ///
    /// Fails with `ErrorKind::UnexpectedEof` if `r` ends before `FS` bytes are read.
    fn read<R: Read>(mut r: R) -> Result<Self> {
        let mut buf = [0; FS];
        r.read_exact(&mut buf)?;

        Ok(FieldElement(buf))
    }

    /// Writes the `FS` raw bytes to `w`.
    fn write<W: Write>(&self, mut w: W) -> Result<()> {
        w.write_all(&self.0[..])
    }
}

/// Wraps a raw `FS`-byte array as a field element.
impl<const FS: usize> From<[u8; FS]> for FieldElement<FS> {
    fn from(array: [u8; FS]) -> Self {
        FieldElement(array)
    }
}

/// Exposes the underlying byte array.
impl<const FS: usize> std::ops::Deref for FieldElement<FS> {
    type Target = [u8; FS];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, ErrorKind};

    const FS: usize = 32;

    /// Fixed, arbitrary 32-byte value used both as witness entries and as the prime.
    fn fe() -> FieldElement<FS> {
        FieldElement::from([
            1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 1,
        ])
    }

    /// A small valid file with three witness values.
    fn sample() -> WtnsFile<FS> {
        WtnsFile::<FS>::from_vec(vec![fe(), fe(), fe()], fe())
    }

    /// Serializes `file` into a fresh buffer, panicking on failure.
    fn encode(file: &WtnsFile<FS>) -> Vec<u8> {
        let mut data = Vec::new();
        file.write(&mut data).unwrap();
        data
    }

    #[test]
    fn roundtrip() {
        let file = sample();
        let new_file = WtnsFile::read(Cursor::new(encode(&file))).unwrap();

        assert_eq!(file, new_file);
    }

    #[test]
    fn encoded_layout() {
        let data = encode(&sample());

        // magic + version + section count
        // + header section (type, size, field size, prime, witness count)
        // + witness section (type, size, three elements)
        assert_eq!(data.len(), 12 + (4 + 8 + 4 + FS + 4) + (4 + 8 + 3 * FS));
        assert_eq!(&data[..4], b"wtns");
    }

    #[test]
    fn rejects_bad_magic() {
        let mut data = encode(&sample());
        data[0] = b'x';

        let err = WtnsFile::<FS>::read(Cursor::new(data)).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn rejects_truncated_input() {
        let mut data = encode(&sample());
        data.pop();

        let err = WtnsFile::<FS>::read(Cursor::new(data)).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }

    #[test]
    fn rejects_mismatched_field_size() {
        let data = encode(&sample());

        let err = WtnsFile::<16>::read(Cursor::new(data)).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }
}