1use std::io::{Error, ErrorKind, Read, Result, Write};
5
6use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
7
/// Four magic bytes (`w`, `t`, `n`, `s`) that every witness file starts with.
const MAGIC: &[u8; 4] = &[b'w', b't', b'n', b's'];
9
10#[derive(Debug, PartialEq)]
11pub struct WtnsFile<const FS: usize> {
12 pub version: u32,
13 pub header: Header<FS>,
14 pub witness: Witness<FS>,
15}
16
17impl<const FS: usize> WtnsFile<FS> {
18 pub fn from_vec(witness: Vec<FieldElement<FS>>, prime: FieldElement<FS>) -> Self {
19 WtnsFile {
20 version: 1,
21 header: Header {
22 field_size: FS as u32,
23 prime,
24 witness_len: witness.len() as u32,
25 },
26 witness: Witness(witness),
27 }
28 }
29
30 pub fn read<R: Read>(mut r: R) -> Result<Self> {
31 let mut magic = [0u8; 4];
32 r.read_exact(&mut magic)?;
33
34 if magic != *MAGIC {
35 return Err(Error::new(ErrorKind::InvalidData, "Invalid magic number"));
36 }
37
38 let version = r.read_u32::<LittleEndian>()?;
39 if version > 2 {
40 return Err(Error::new(ErrorKind::InvalidData, "Unsupported version"));
41 }
42
43 let num_sections = r.read_u32::<LittleEndian>()?;
44 if num_sections > 2 {
45 return Err(Error::new(
46 ErrorKind::InvalidData,
47 "Number of sections >2 is not supported",
48 ));
49 }
50
51 let header = Header::read(&mut r)?;
52 let witness = Witness::read(&mut r, &header)?;
53
54 Ok(WtnsFile {
55 version,
56 header,
57 witness,
58 })
59 }
60
61 pub fn write<W: Write>(&self, mut w: W) -> Result<()> {
62 w.write_all(MAGIC)?;
63 w.write_u32::<LittleEndian>(self.version)?;
64 w.write_u32::<LittleEndian>(2)?;
65 self.header.write(&mut w)?;
66 self.witness.write(&mut w)?;
67
68 Ok(())
69 }
70}
71
72#[derive(Debug, PartialEq)]
73pub struct Header<const FS: usize> {
74 pub field_size: u32,
75 pub prime: FieldElement<FS>,
76 pub witness_len: u32,
77}
78
79impl<const FS: usize> Header<FS> {
80 pub fn read<R: Read>(mut r: R) -> Result<Self> {
81 let sec_type = SectionType::read(&mut r)?;
82 if sec_type != SectionType::Header {
83 return Err(Error::new(
84 ErrorKind::InvalidData,
85 "Invalid section type: expected header",
86 ));
87 }
88
89 let sec_size = r.read_u64::<LittleEndian>()?;
90 if sec_size != 4 + FS as u64 + 4 {
91 return Err(Error::new(
92 ErrorKind::InvalidData,
93 "Invalid header section size",
94 ));
95 }
96
97 let field_size = r.read_u32::<LittleEndian>()?;
98 let prime = FieldElement::read(&mut r)?;
99
100 if field_size != FS as u32 {
101 return Err(Error::new(ErrorKind::InvalidData, "Wrong field size"));
102 }
103
104 let witness_len = r.read_u32::<LittleEndian>()?;
105
106 Ok(Header {
107 field_size,
108 prime,
109 witness_len,
110 })
111 }
112
113 pub fn write<W: Write>(&self, mut w: W) -> Result<()> {
114 SectionType::Header.write(&mut w)?;
115
116 let sec_size = 4 + FS as u64 + 4;
117 w.write_u64::<LittleEndian>(sec_size)?;
118
119 w.write_u32::<LittleEndian>(FS as u32)?;
120 self.prime.write(&mut w)?;
121 w.write_u32::<LittleEndian>(self.witness_len)?;
122
123 Ok(())
124 }
125}
126
127#[derive(Debug, PartialEq)]
128pub struct Witness<const FS: usize>(pub Vec<FieldElement<FS>>);
129
130impl<const FS: usize> Witness<FS> {
131 pub fn read<R: Read>(mut r: R, header: &Header<FS>) -> Result<Self> {
132 let sec_type = SectionType::read(&mut r)?;
133 if sec_type != SectionType::Witness {
134 return Err(Error::new(ErrorKind::InvalidData, "Invalid section type: expected witness"));
135 }
136 let sec_size = r.read_u64::<LittleEndian>()?;
137
138 if sec_size != header.witness_len as u64 * FS as u64 {
139 return Err(Error::new(
140 ErrorKind::InvalidData,
141 "Invalid witness section size",
142 ));
143 }
144
145 let mut witness = Vec::with_capacity(header.witness_len as usize);
146 for _ in 0..header.witness_len {
147 witness.push(FieldElement::read(&mut r)?);
148 }
149
150 Ok(Witness(witness))
151 }
152
153 fn write<W: Write>(&self, mut w: W) -> Result<()> {
154 SectionType::Witness.write(&mut w)?;
155
156 let sec_size = (self.0.len() * FS) as u64;
157 w.write_u64::<LittleEndian>(sec_size)?;
158
159 for e in &self.0 {
160 e.write(&mut w)?;
161 }
162
163 Ok(())
164 }
165}
166
/// Section identifiers, encoded on the wire as little-endian `u32` values
/// equal to the discriminants below.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u32)]
pub enum SectionType {
    /// Header section (id 1).
    Header = 1,
    /// Witness data section (id 2).
    Witness = 2,
    /// Any other id encountered while reading.
    Unknown = u32::MAX,
}
174
175impl SectionType {
176 fn read<R: Read>(mut r: R) -> Result<Self> {
177 let num = r.read_u32::<LittleEndian>()?;
178
179 let ty = match num {
180 1 => SectionType::Header,
181 2 => SectionType::Witness,
182 _ => SectionType::Unknown,
183 };
184
185 Ok(ty)
186 }
187
188 fn write<W: Write>(&self, mut w: W) -> Result<()> {
189 w.write_u32::<LittleEndian>(*self as u32)?;
190
191 Ok(())
192 }
193}
194
/// A field element held as its raw `FS`-byte encoding. Bytes are copied in
/// and out verbatim, with no interpretation or reduction.
#[derive(Debug, Eq, PartialEq)]
pub struct FieldElement<const FS: usize>([u8; FS]);
197
198impl<const FS: usize> FieldElement<FS> {
199 pub fn as_bytes(&self) -> &[u8] {
200 &self.0[..]
201 }
202
203 fn read<R: Read>(mut r: R) -> Result<Self> {
204 let mut buf = [0; FS];
205 r.read_exact(&mut buf)?;
206
207 Ok(FieldElement(buf))
208 }
209
210 fn write<W: Write>(&self, mut w: W) -> Result<()> {
211 w.write_all(&self.0[..])
212 }
213}
214
215impl<const FS: usize> From<[u8; FS]> for FieldElement<FS> {
216 fn from(array: [u8; FS]) -> Self {
217 FieldElement(array)
218 }
219}
220
221impl<const FS: usize> std::ops::Deref for FieldElement<FS> {
222 type Target = [u8; FS];
223
224 fn deref(&self) -> &Self::Target {
225 &self.0
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    const FS: usize = 32;

    /// A fixed, non-trivial element: bytes 0, 2, 4, 6 and the last byte are
    /// set to 1, everything else is 0.
    fn fe() -> FieldElement<FS> {
        let mut bytes = [0u8; FS];
        for &idx in &[0, 2, 4, 6, FS - 1] {
            bytes[idx] = 1;
        }
        FieldElement::from(bytes)
    }

    #[test]
    fn write_then_read_round_trips() {
        let original = WtnsFile::<FS>::from_vec((0..3).map(|_| fe()).collect(), fe());

        let mut encoded = Vec::new();
        original.write(&mut encoded).unwrap();

        let decoded = WtnsFile::read(Cursor::new(encoded)).unwrap();
        assert_eq!(original, decoded);
    }
}