wirehair_wrapper/
lib.rs

1pub mod wirehair {
2    use std::fmt::{Display, Error, Formatter};
3    use std::os::raw::{c_int, c_void};
4    use std::ptr::null;
5
    // Raw status code returned by the `extern "C"` wirehair functions declared
    // in this module. It is private; `parse_wirehair_result` converts it into
    // the public `WirehairResult` / `WirehairError` types.
    #[repr(C)]
    enum WirehairResultCode {
        // Success code
        Success = 0,
        // More data is needed to decode. This is normal and does not indicate a failure
        NeedMore = 1,
        // All values from here up to `UnsupportedPlatform` are failure codes.
        // A function parameter was invalid
        InvalidInput = 2,
        // Encoder needs a better dense seed
        BadDenseSeed = 3,
        // Encoder needs a better peel seed
        BadPeelSeed = 4,
        // N = ceil(messageBytes / blockBytes) is too small.
        // Try reducing block_size or use a larger message
        BadInputSmallN = 5,
        // N = ceil(messageBytes / blockBytes) is too large.
        // Try increasing block_size or use a smaller message
        BadInputLargeN = 6,
        // Not enough extra rows to solve it, must give up
        ExtraInsufficient = 7,
        // An error occurred during the request
        Error = 8,
        // Out of memory
        OOM = 9,
        // Platform is not supported yet
        UnsupportedPlatform = 10,
        // For asserts. No explicit value, so it takes the next discriminant (11).
        Count,
        // int32_t padding
        Padding = 0x7fff_ffff,
    }
38
39    #[link(name = "wirehair")]
40    extern "C" {
41        fn wirehair_init_(version: c_int) -> WirehairResultCode;
42        fn wirehair_encoder_create(
43            reuse_codec_opt: *const c_void,
44            message: *const u8,
45            message_size_bytes: u64,
46            block_size_bytes: u32,
47        ) -> *const c_void;
48        fn wirehair_encode(
49            codec: *const c_void,
50            block_id: u64,
51            block: *mut u8,
52            block_size: u32,
53            block_out_bytes: &mut u32,
54        ) -> WirehairResultCode;
55        fn wirehair_decoder_create(
56            reuse_codec_opt: *const c_void,
57            message_size_bytes: u64,
58            block_size_bytes: u32,
59        ) -> *const c_void;
60        fn wirehair_decode(
61            codec: *const c_void,
62            block_id: u64,
63            block: *const u8,
64            block_out_bytes: u32,
65        ) -> WirehairResultCode;
66        fn wirehair_recover(
67            codec: *const c_void,
68            message: *mut u8,
69            message_size_bytes: u64,
70        ) -> WirehairResultCode;
71        fn wirehair_decoder_becomes_encoder(codec: *const c_void) -> WirehairResultCode;
72        fn wirehair_free(codec: *const c_void) -> c_void;
73    }
74
75    #[derive(Debug, PartialEq)]
76    pub enum WirehairError {
77        InvalidInput,
78        BadDenseSeed,
79        BadPeelSeed,
80        BadInputSmallN,
81        BadInputLargeN,
82        ExtraInsufficient,
83        Error,
84        OOM,
85        UnsupportedPlatform,
86    }
87
88    impl Display for WirehairError {
89        fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
90            match *self {
91                WirehairError::InvalidInput => write!(f, "A function parameter was invalid"),
92                WirehairError::BadDenseSeed => write!(f, "Encoder needs a better dense seed"),
93                WirehairError::BadPeelSeed => write!(f, "Encoder needs a better peel seed"),
94                WirehairError::BadInputSmallN => write!(
95                    f,
96                    "Too less blocks! Try reducing block size or use a larger message"
97                ),
98                WirehairError::BadInputLargeN => write!(
99                    f,
100                    "Too many blocks! Try increasing block_size or use a smaller message"
101                ),
102                WirehairError::ExtraInsufficient => write!(
103                    f,
104                    "Not enough extra rows to solve it, possibly corrupted data"
105                ),
106                WirehairError::Error => write!(f, "Unexpected error"),
107                WirehairError::OOM => write!(f, "Out of memory"),
108                WirehairError::UnsupportedPlatform => write!(f, "Platform is not supported yet"),
109            }
110        }
111    }
112
113    #[derive(Debug, PartialEq)]
114    pub enum WirehairResult {
115        Success,
116        NeedMore,
117        Internal,
118    }
119
120    fn parse_wirehair_result(result: WirehairResultCode) -> Result<WirehairResult, WirehairError> {
121        match result {
122            WirehairResultCode::InvalidInput => Err(WirehairError::InvalidInput),
123            WirehairResultCode::BadDenseSeed => Err(WirehairError::BadDenseSeed),
124            WirehairResultCode::BadPeelSeed => Err(WirehairError::BadPeelSeed),
125            WirehairResultCode::BadInputSmallN => Err(WirehairError::BadInputSmallN),
126            WirehairResultCode::BadInputLargeN => Err(WirehairError::BadInputLargeN),
127            WirehairResultCode::ExtraInsufficient => Err(WirehairError::ExtraInsufficient),
128            WirehairResultCode::Error => Err(WirehairError::Error),
129            WirehairResultCode::OOM => Err(WirehairError::OOM),
130            WirehairResultCode::UnsupportedPlatform => Err(WirehairError::UnsupportedPlatform),
131            WirehairResultCode::Success => Ok(WirehairResult::Success),
132            WirehairResultCode::NeedMore => Ok(WirehairResult::NeedMore),
133            _ => Ok(WirehairResult::Internal),
134        }
135    }
136
137    pub fn wirehair_init() -> Result<(), WirehairError> {
138        let result = unsafe { parse_wirehair_result(wirehair_init_(2)) };
139        match result {
140            Ok(_r) => Ok(()),
141            Err(e) => Err(e),
142        }
143    }
144
145    pub fn wirehair_decoder_to_encoder(
146        decoder: WirehairDecoder,
147    ) -> Result<WirehairEncoder, WirehairError> {
148        let result = unsafe { wirehair_decoder_becomes_encoder(decoder.native_handler) };
149
150        match parse_wirehair_result(result) {
151            Ok(_) => Ok(WirehairEncoder {
152                native_handler: decoder.native_handler,
153            }),
154            Err(e) => Err(e),
155        }
156    }
157
158    pub struct WirehairEncoder {
159        native_handler: *const c_void,
160    }
161
162    impl WirehairEncoder {
163        pub fn new(
164            message: &[u8],
165            message_size_bytes: u64,
166            block_size_bytes: u32,
167        ) -> WirehairEncoder {
168            WirehairEncoder {
169                native_handler: unsafe {
170                    wirehair_encoder_create(
171                        null::<c_void>(),
172                        message.as_ptr(),
173                        message_size_bytes,
174                        block_size_bytes,
175                    )
176                },
177            }
178        }
179
180        pub fn encode(
181            &self,
182            block_id: u64,
183            block: &mut [u8],
184            block_size: u32,
185            block_out_bytes: &mut u32,
186        ) -> Result<WirehairResult, WirehairError> {
187            let result = unsafe {
188                wirehair_encode(
189                    self.native_handler,
190                    block_id,
191                    block.as_mut_ptr(),
192                    block_size,
193                    block_out_bytes,
194                )
195            };
196
197            parse_wirehair_result(result)
198        }
199    }
200
201    impl Drop for WirehairEncoder {
202        fn drop(&mut self) {
203            unsafe { wirehair_free(self.native_handler) };
204        }
205    }
206
207    pub struct WirehairDecoder {
208        native_handler: *const c_void,
209    }
210
211    impl WirehairDecoder {
212        pub fn new(message_size_bytes: u64, block_size_bytes: u32) -> WirehairDecoder {
213            WirehairDecoder {
214                native_handler: unsafe {
215                    wirehair_decoder_create(null::<c_void>(), message_size_bytes, block_size_bytes)
216                },
217            }
218        }
219
220        pub fn decode(
221            &self,
222            block_id: u64,
223            block: &[u8],
224            block_out_size_bytes: u32,
225        ) -> Result<WirehairResult, WirehairError> {
226            let result = unsafe {
227                wirehair_decode(
228                    self.native_handler,
229                    block_id,
230                    block.as_ptr(),
231                    block_out_size_bytes,
232                )
233            };
234
235            parse_wirehair_result(result)
236        }
237
238        pub fn recover(
239            &self,
240            message: &mut [u8],
241            message_size_bytes: u64,
242        ) -> Result<WirehairResult, WirehairError> {
243            let result = unsafe {
244                wirehair_recover(
245                    self.native_handler,
246                    message.as_mut_ptr(),
247                    message_size_bytes,
248                )
249            };
250
251            parse_wirehair_result(result)
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::wirehair::*;

    /// Encodes a 500-byte message in 50-byte blocks, withholds every fifth
    /// block from the decoder, and checks that the message is still recovered
    /// intact and that the decoder can be turned into an encoder afterwards.
    #[test]
    fn basic_flow_works() {
        assert!(wirehair_init().is_ok());

        let mut message = [0u8; 500];
        for (i, byte) in message.iter_mut().enumerate() {
            // Truncation is intended: the pattern repeats every 256 bytes.
            *byte = i as u8;
        }

        let encoder = WirehairEncoder::new(&message, 500, 50);
        let decoder = WirehairDecoder::new(500, 50);

        let mut block_id = 0;

        loop {
            let mut block = [0u8; 50];
            let mut block_out_bytes: u32 = 0;
            let result = encoder.encode(block_id, &mut block, 50, &mut block_out_bytes);
            assert!(result.is_ok());

            let current_id = block_id;
            block_id += 1;

            // Simulate packet loss: every fifth block never reaches the decoder.
            if current_id % 5 == 0 {
                continue;
            }

            match decoder.decode(current_id, &block, block_out_bytes) {
                Ok(WirehairResult::NeedMore) => continue,
                Ok(WirehairResult::Success) => break,
                other => panic!("unexpected decode result: {:?}", other),
            }
        }

        let mut decoded_message = [0u8; 500];

        let result = decoder.recover(&mut decoded_message, 500);
        assert!(result.is_ok());

        // The round trip must reproduce the original bytes exactly. Slices are
        // compared so the assertion does not depend on array trait impls.
        assert_eq!(&decoded_message[..], &message[..]);

        assert!(wirehair_decoder_to_encoder(decoder).is_ok());
    }
}
304}