pub mod wirehair {
    //! Rust bindings for the Wirehair fountain-code C library.
    //!
    //! Call [`wirehair_init`] before creating encoders or decoders. Every safe method checks
    //! that the sizes it forwards to C fit inside the Rust buffers it passes, so the native
    //! code is never asked to read or write past the end of a slice.

    use std::fmt::{Display, Error, Formatter};
    use std::mem::ManuallyDrop;
    use std::os::raw::{c_int, c_uint, c_void};
    use std::ptr;

    /// Version number these bindings target; passed to `wirehair_init_`.
    const WIREHAIR_VERSION: c_int = 2;

    /// Raw `WirehairResult` codes returned by the C library.
    ///
    /// The FFI declarations return a plain `c_int` rather than a Rust enum: receiving an
    /// integer that is not a declared discriminant of a Rust enum is undefined behavior, so
    /// the raw value is matched against these constants instead.
    mod code {
        use std::os::raw::c_int;

        pub const SUCCESS: c_int = 0;
        pub const NEED_MORE: c_int = 1;
        pub const INVALID_INPUT: c_int = 2;
        pub const BAD_DENSE_SEED: c_int = 3;
        pub const BAD_PEEL_SEED: c_int = 4;
        pub const BAD_INPUT_SMALL_N: c_int = 5;
        pub const BAD_INPUT_LARGE_N: c_int = 6;
        pub const EXTRA_INSUFFICIENT: c_int = 7;
        pub const ERROR: c_int = 8;
        pub const OOM: c_int = 9;
        pub const UNSUPPORTED_PLATFORM: c_int = 10;
    }

    #[link(name = "wirehair")]
    extern "C" {
        fn wirehair_init_(expected_version: c_int) -> c_int;
        fn wirehair_encoder_create(
            reuse_codec_opt: *mut c_void,
            message: *const u8,
            message_size_bytes: u64,
            block_size_bytes: u32,
        ) -> *mut c_void;
        // NOTE: the C header declares `blockId` as `unsigned`, not a 64-bit integer.
        fn wirehair_encode(
            codec: *mut c_void,
            block_id: c_uint,
            block: *mut u8,
            block_size: u32,
            block_out_bytes: *mut u32,
        ) -> c_int;
        fn wirehair_decoder_create(
            reuse_codec_opt: *mut c_void,
            message_size_bytes: u64,
            block_size_bytes: u32,
        ) -> *mut c_void;
        fn wirehair_decode(
            codec: *mut c_void,
            block_id: c_uint,
            block: *const u8,
            block_size_bytes: u32,
        ) -> c_int;
        fn wirehair_recover(codec: *mut c_void, message: *mut u8, message_size_bytes: u64)
            -> c_int;
        fn wirehair_decoder_becomes_encoder(codec: *mut c_void) -> c_int;
        // The C function returns `void`; declaring a `c_void` return value would be wrong.
        fn wirehair_free(codec: *mut c_void);
    }

    /// Errors reported by the Wirehair library (or by the argument checks in these bindings).
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum WirehairError {
        InvalidInput,
        BadDenseSeed,
        BadPeelSeed,
        BadInputSmallN,
        BadInputLargeN,
        ExtraInsufficient,
        Error,
        OOM,
        UnsupportedPlatform,
    }

    impl Display for WirehairError {
        fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
            let text = match *self {
                WirehairError::InvalidInput => "A function parameter was invalid",
                WirehairError::BadDenseSeed => "Encoder needs a better dense seed",
                WirehairError::BadPeelSeed => "Encoder needs a better peel seed",
                WirehairError::BadInputSmallN => {
                    "Too less blocks! Try reducing block size or use a larger message"
                }
                WirehairError::BadInputLargeN => {
                    "Too many blocks! Try increasing block_size or use a smaller message"
                }
                WirehairError::ExtraInsufficient => {
                    "Not enough extra rows to solve it, possibly corrupted data"
                }
                WirehairError::Error => "Unexpected error",
                WirehairError::OOM => "Out of memory",
                WirehairError::UnsupportedPlatform => "Platform is not supported yet",
            };
            f.write_str(text)
        }
    }

    impl std::error::Error for WirehairError {}

    /// Non-error outcomes of a Wirehair call.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum WirehairResult {
        /// The operation completed.
        Success,
        /// The decoder needs more blocks before the message can be recovered.
        NeedMore,
        /// The library returned a code that is neither a known success nor a known error.
        Internal,
    }

    /// Maps a raw C result code onto the Rust success/error enums.
    fn parse_wirehair_result(result: c_int) -> Result<WirehairResult, WirehairError> {
        match result {
            code::SUCCESS => Ok(WirehairResult::Success),
            code::NEED_MORE => Ok(WirehairResult::NeedMore),
            code::INVALID_INPUT => Err(WirehairError::InvalidInput),
            code::BAD_DENSE_SEED => Err(WirehairError::BadDenseSeed),
            code::BAD_PEEL_SEED => Err(WirehairError::BadPeelSeed),
            code::BAD_INPUT_SMALL_N => Err(WirehairError::BadInputSmallN),
            code::BAD_INPUT_LARGE_N => Err(WirehairError::BadInputLargeN),
            code::EXTRA_INSUFFICIENT => Err(WirehairError::ExtraInsufficient),
            code::ERROR => Err(WirehairError::Error),
            code::OOM => Err(WirehairError::OOM),
            code::UNSUPPORTED_PLATFORM => Err(WirehairError::UnsupportedPlatform),
            // Sentinels such as the C enum's count/padding values, or anything unknown.
            _ => Ok(WirehairResult::Internal),
        }
    }

    /// Returns `true` when `requested` bytes fit inside a buffer of `len` bytes.
    fn fits(len: usize, requested: u64) -> bool {
        // `usize` is at most 64 bits on every Rust target, so this widening is lossless.
        requested <= len as u64
    }

    /// Narrows a caller-supplied block id to the C `unsigned` the library expects.
    ///
    /// Ids that do not fit are rejected instead of being silently truncated by the ABI.
    fn to_block_id(block_id: u64) -> Result<c_uint, WirehairError> {
        if block_id > u64::from(c_uint::MAX) {
            Err(WirehairError::InvalidInput)
        } else {
            // Cannot truncate: the range was checked above.
            Ok(block_id as c_uint)
        }
    }

    /// Initializes the native Wirehair library (API version 2).
    ///
    /// The C library expects this to be called before any codec is created.
    ///
    /// # Errors
    /// Returns the error reported by `wirehair_init_`, e.g.
    /// [`WirehairError::UnsupportedPlatform`].
    pub fn wirehair_init() -> Result<(), WirehairError> {
        // SAFETY: `wirehair_init_` only receives an integer; no Rust memory is shared.
        let code = unsafe { wirehair_init_(WIREHAIR_VERSION) };
        parse_wirehair_result(code).map(|_| ())
    }

    /// Converts a decoder (typically one that has recovered its message) into an encoder.
    ///
    /// The native codec is handed over to the returned encoder. If the conversion fails, the
    /// decoder is dropped and its native codec is freed.
    ///
    /// # Errors
    /// [`WirehairError::InvalidInput`] for a decoder whose native creation failed, otherwise
    /// the error reported by `wirehair_decoder_becomes_encoder`.
    pub fn wirehair_decoder_to_encoder(
        decoder: WirehairDecoder,
    ) -> Result<WirehairEncoder, WirehairError> {
        if decoder.native_handler.is_null() {
            return Err(WirehairError::InvalidInput);
        }
        // SAFETY: the handle is a live codec exclusively owned by `decoder`.
        let code = unsafe { wirehair_decoder_becomes_encoder(decoder.native_handler) };
        // On failure, `decoder` is dropped on return, which frees the native codec.
        parse_wirehair_result(code)?;
        // The codec now belongs to the encoder: suppress the decoder's destructor so the
        // handle is not freed twice.
        let decoder = ManuallyDrop::new(decoder);
        Ok(WirehairEncoder {
            native_handler: decoder.native_handler,
        })
    }

    /// Owning handle to a native Wirehair encoder; the codec is freed on drop.
    ///
    /// The handle is null when creation was rejected (by the native library or by the
    /// size check in [`WirehairEncoder::new`]); every operation on such an encoder returns
    /// [`WirehairError::InvalidInput`].
    pub struct WirehairEncoder {
        native_handler: *mut c_void,
    }

    impl WirehairEncoder {
        /// Creates an encoder over the first `message_size_bytes` bytes of `message`,
        /// splitting it into blocks of `block_size_bytes` bytes.
        ///
        /// If `message_size_bytes` exceeds `message.len()`, the native library is not called
        /// (it would read past the slice) and the returned encoder is invalid.
        ///
        /// NOTE(review): the C encoder may keep a pointer into `message` instead of copying
        /// it — TODO confirm against wirehair.h. If so, `message` must outlive the encoder
        /// and this type needs a lifetime parameter to enforce that.
        pub fn new(
            message: &[u8],
            message_size_bytes: u64,
            block_size_bytes: u32,
        ) -> WirehairEncoder {
            let native_handler = if fits(message.len(), message_size_bytes) {
                // SAFETY: `message` is readable for `message_size_bytes` bytes (checked above).
                unsafe {
                    wirehair_encoder_create(
                        ptr::null_mut(),
                        message.as_ptr(),
                        message_size_bytes,
                        block_size_bytes,
                    )
                }
            } else {
                ptr::null_mut()
            };
            WirehairEncoder { native_handler }
        }

        /// Writes block `block_id` into `block` (capacity `block_size` bytes) and stores the
        /// number of bytes written in `block_out_bytes`.
        ///
        /// # Errors
        /// [`WirehairError::InvalidInput`] if the encoder is invalid, `block_size` exceeds
        /// `block.len()`, or `block_id` does not fit in a C `unsigned`; otherwise the error
        /// reported by `wirehair_encode`.
        pub fn encode(
            &self,
            block_id: u64,
            block: &mut [u8],
            block_size: u32,
            block_out_bytes: &mut u32,
        ) -> Result<WirehairResult, WirehairError> {
            if self.native_handler.is_null() || !fits(block.len(), u64::from(block_size)) {
                return Err(WirehairError::InvalidInput);
            }
            let block_id = to_block_id(block_id)?;
            // SAFETY: the handle is a live codec owned by `self`; `block` is writable for
            // `block_size` bytes (checked above); `block_out_bytes` points to a valid u32.
            let code = unsafe {
                wirehair_encode(
                    self.native_handler,
                    block_id,
                    block.as_mut_ptr(),
                    block_size,
                    block_out_bytes,
                )
            };
            parse_wirehair_result(code)
        }
    }

    impl Drop for WirehairEncoder {
        fn drop(&mut self) {
            if !self.native_handler.is_null() {
                // SAFETY: the handle came from the C library, is owned only by `self`, and
                // is freed exactly once here.
                unsafe { wirehair_free(self.native_handler) };
            }
        }
    }

    /// Owning handle to a native Wirehair decoder.
    ///
    /// The native codec is freed when the decoder is dropped, unless it has been handed
    /// over to an encoder by [`wirehair_decoder_to_encoder`].
    pub struct WirehairDecoder {
        native_handler: *mut c_void,
    }

    impl WirehairDecoder {
        /// Creates a decoder for a message of `message_size_bytes` bytes sent in blocks of
        /// `block_size_bytes` bytes.
        ///
        /// If the native library rejects the parameters, the handle is null and every
        /// operation returns [`WirehairError::InvalidInput`].
        pub fn new(message_size_bytes: u64, block_size_bytes: u32) -> WirehairDecoder {
            // SAFETY: only integers and a null "reuse" pointer are passed.
            let native_handler = unsafe {
                wirehair_decoder_create(ptr::null_mut(), message_size_bytes, block_size_bytes)
            };
            WirehairDecoder { native_handler }
        }

        /// Feeds block `block_id`, whose first `block_out_size_bytes` bytes of `block` hold
        /// data, into the decoder.
        ///
        /// Returns `NeedMore` until enough blocks have arrived, then `Success`.
        ///
        /// # Errors
        /// [`WirehairError::InvalidInput`] if the decoder is invalid, `block_out_size_bytes`
        /// exceeds `block.len()`, or `block_id` does not fit in a C `unsigned`; otherwise
        /// the error reported by `wirehair_decode`.
        pub fn decode(
            &self,
            block_id: u64,
            block: &[u8],
            block_out_size_bytes: u32,
        ) -> Result<WirehairResult, WirehairError> {
            if self.native_handler.is_null()
                || !fits(block.len(), u64::from(block_out_size_bytes))
            {
                return Err(WirehairError::InvalidInput);
            }
            let block_id = to_block_id(block_id)?;
            // SAFETY: the handle is a live codec owned by `self`; `block` is readable for
            // `block_out_size_bytes` bytes (checked above).
            let code = unsafe {
                wirehair_decode(
                    self.native_handler,
                    block_id,
                    block.as_ptr(),
                    block_out_size_bytes,
                )
            };
            parse_wirehair_result(code)
        }

        /// Writes the recovered message (`message_size_bytes` bytes) into `message`.
        ///
        /// # Errors
        /// [`WirehairError::InvalidInput`] if the decoder is invalid or `message_size_bytes`
        /// exceeds `message.len()`; otherwise the error reported by `wirehair_recover`.
        pub fn recover(
            &self,
            message: &mut [u8],
            message_size_bytes: u64,
        ) -> Result<WirehairResult, WirehairError> {
            if self.native_handler.is_null() || !fits(message.len(), message_size_bytes) {
                return Err(WirehairError::InvalidInput);
            }
            // SAFETY: the handle is a live codec owned by `self`; `message` is writable for
            // `message_size_bytes` bytes (checked above).
            let code = unsafe {
                wirehair_recover(self.native_handler, message.as_mut_ptr(), message_size_bytes)
            };
            parse_wirehair_result(code)
        }
    }

    impl Drop for WirehairDecoder {
        fn drop(&mut self) {
            if !self.native_handler.is_null() {
                // SAFETY: the handle came from the C library, is owned only by `self`, and
                // is freed exactly once here (ownership transfer uses `ManuallyDrop`).
                unsafe { wirehair_free(self.native_handler) };
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::wirehair::*;

    const MESSAGE_SIZE: usize = 500;
    const BLOCK_SIZE: u32 = 50;
    /// Block ids tried before giving up, so a non-converging decoder fails instead of hanging.
    const MAX_BLOCK_ID: u64 = 1_000;

    /// Encodes a message, drops every fifth block, decodes the rest, checks the recovered
    /// bytes and finally turns the decoder into a working encoder.
    #[test]
    fn basic_flow_works() {
        assert!(wirehair_init().is_ok());

        // Deterministic payload: byte i holds i mod 256.
        let message: Vec<u8> = (0..=u8::MAX).cycle().take(MESSAGE_SIZE).collect();

        let encoder = WirehairEncoder::new(&message, MESSAGE_SIZE as u64, BLOCK_SIZE);
        let decoder = WirehairDecoder::new(MESSAGE_SIZE as u64, BLOCK_SIZE);

        let mut recovered = false;
        for block_id in 0..MAX_BLOCK_ID {
            let mut block = [0u8; BLOCK_SIZE as usize];
            let mut block_out_bytes: u32 = 0;
            encoder
                .encode(block_id, &mut block, BLOCK_SIZE, &mut block_out_bytes)
                .expect("encoding a block should succeed");

            // Simulate packet loss by dropping every fifth block.
            if block_id % 5 == 0 {
                continue;
            }

            match decoder.decode(block_id, &block, block_out_bytes) {
                Ok(WirehairResult::NeedMore) => {}
                Ok(WirehairResult::Success) => {
                    recovered = true;
                    break;
                }
                other => panic!("unexpected decode result: {:?}", other),
            }
        }
        assert!(recovered, "decoder did not converge within {} blocks", MAX_BLOCK_ID);

        let mut decoded_message = vec![0u8; MESSAGE_SIZE];
        decoder
            .recover(&mut decoded_message, MESSAGE_SIZE as u64)
            .expect("recovery should succeed once decoding reported Success");
        // Compare against the source so a corrupt recovery fails the test.
        assert_eq!(decoded_message, message);

        let reencoder =
            wirehair_decoder_to_encoder(decoder).expect("decoder should convert into an encoder");
        let mut block = [0u8; BLOCK_SIZE as usize];
        let mut block_out_bytes: u32 = 0;
        reencoder
            .encode(0, &mut block, BLOCK_SIZE, &mut block_out_bytes)
            .expect("converted encoder should produce blocks");
        assert!(block_out_bytes > 0);
    }
}