servant_codec/aead_codec/
aead.rs

1// -- aead.rs --
2
3use {
4    bytes::{Buf, BufMut, Bytes, BytesMut},
5    futures_codec::{Decoder, Encoder},
6    ring::{
7        aead::{
8            Aad, Algorithm as AeadAlgorithm, BoundKey, Nonce, NonceSequence, OpeningKey,
9            SealingKey, UnboundKey, CHACHA20_POLY1305,
10        },
11        error::Unspecified,
12        hkdf::{Algorithm as HkdfAlgorithm, Salt, HKDF_SHA256},
13        rand::*,
14    },
15    std::{cell::RefCell, rc::Rc},
16};
17
18// --
19
/// Application-specific HKDF `info` input, used when expanding the PSK-derived pseudorandom key
/// into an AEAD key for either direction of the stream.
const APP_INFO: &[u8] = "bee lib".as_bytes();
21
22// --
23
/// Per-direction nonce state for an AEAD key: the raw nonce bytes, all zero when created and
/// stepped forward by `NonceSequence::advance`.
struct Sequence(Vec<u8>);

impl Sequence {
    /// Returns an all-zero nonce buffer of `nonce_len` bytes.
    fn new(nonce_len: usize) -> Self {
        let zeroed: Vec<u8> = std::iter::repeat(0_u8).take(nonce_len).collect();
        Self(zeroed)
    }
}
31
32impl NonceSequence for Sequence {
33    fn advance(&mut self) -> Result<Nonce, Unspecified> {
34        assert_eq!(
35            std::mem::size_of::<u32>() + std::mem::size_of::<u64>(),
36            self.0.len()
37        );
38
39        let ptr = self.0.as_mut_ptr();
40        let p_u32 = ptr as *mut u32;
41        let p_u64: *mut u64;
42        unsafe {
43            p_u64 = ptr.offset(std::mem::size_of::<u32>() as isize) as *mut u64;
44            *p_u64 += 1;
45            if *p_u64 == u64::max_value() {
46                *p_u64 = 0;
47
48                *p_u32 += 1;
49                if *p_u32 == u32::max_value() {
50                    *p_u32 = 0;
51                }
52            }
53        }
54        Nonce::try_assume_unique_for_key(&self.0)
55    }
56}
57
58// --
59
60#[derive(Clone)]
61pub struct Builder {
62    hkdf_algorithm: &'static HkdfAlgorithm,
63    aead_algorithm: &'static AeadAlgorithm,
64    salt_len: usize,
65    padding_len: u8,
66}
67
68impl Default for Builder {
69    fn default() -> Self {
70        Self {
71            hkdf_algorithm: &HKDF_SHA256,
72            aead_algorithm: &CHACHA20_POLY1305,
73            salt_len: 12,
74            padding_len: 128,
75        }
76    }
77}
78
79impl Builder {
80    pub fn set_hkdf_algorithm(&mut self, a: &'static HkdfAlgorithm) -> &mut Self {
81        self.hkdf_algorithm = a;
82        self
83    }
84    pub fn set_aead_algorithm(&mut self, a: &'static AeadAlgorithm) -> &mut Self {
85        self.aead_algorithm = a;
86        self
87    }
88    pub fn set_salt_len(&mut self, len: usize) -> &mut Self {
89        self.salt_len = len;
90        self
91    }
92    pub fn set_padding_len(&mut self, len: u8) -> &mut Self {
93        self.padding_len = len;
94        self
95    }
96    pub fn create(self, psk: &str) -> AeadCodec {
97        AeadCodec::new(self, psk)
98    }
99}
100
101// --
102
103#[derive(Clone)]
104pub struct AeadCodec {
105    builder: Rc<Builder>,
106    psk: Rc<String>,
107    sealing_key: Option<Rc<RefCell<SealingKey<Sequence>>>>,
108    opening_key: Option<Rc<RefCell<OpeningKey<Sequence>>>>,
109    body_len: Option<usize>,
110}
111
112impl AeadCodec {
113    fn new(builder: Builder, psk: &str) -> Self {
114        Self {
115            builder: Rc::new(builder),
116            psk: Rc::new(String::from(psk)),
117            sealing_key: None,
118            opening_key: None,
119            body_len: None,
120        }
121    }
122
123    fn get_padding(&self) -> Vec<u8> {
124        let rng = SystemRandom::new();
125        let mut padding_len = crate::utility::no_zero_rand_gen::<u8>(&rng);
126        padding_len %= self.builder.padding_len;
127        padding_len += 1;
128
129        let mut v = Vec::<u8>::new();
130        v.push(padding_len);
131        v.resize(padding_len as usize, 0);
132        v
133    }
134
135    fn derive_encode_key(&mut self, buf: &mut BytesMut) {
136        assert!(self.sealing_key.is_none());
137
138        let mut salt = vec![0_u8; self.builder.salt_len];
139        let rng = SystemRandom::new();
140        rng.fill(&mut salt).unwrap();
141        buf.put_slice(&salt);
142
143        let salt = Salt::new(*self.builder.hkdf_algorithm, &salt);
144        let prk = salt.extract(self.psk.as_bytes());
145        let okm = prk
146            .expand(&[APP_INFO], self.builder.aead_algorithm)
147            .unwrap();
148        let ubk = UnboundKey::from(okm);
149        let nonce_len = self.builder.aead_algorithm.nonce_len();
150        self.sealing_key = Some(Rc::new(RefCell::new(SealingKey::new(
151            ubk,
152            Sequence::new(nonce_len),
153        ))));
154    }
155
156    fn derive_decode_key(&mut self, buf: &mut BytesMut) {
157        assert!(self.opening_key.is_none());
158
159        let salt = buf.split_to(self.builder.salt_len);
160        let salt = Salt::new(*self.builder.hkdf_algorithm, &salt);
161        let prk = salt.extract(self.psk.as_bytes());
162        let okm = prk
163            .expand(&[APP_INFO], self.builder.aead_algorithm)
164            .unwrap();
165        let ubk = UnboundKey::from(okm);
166        let nonce_len = self.builder.aead_algorithm.nonce_len();
167        self.opening_key = Some(Rc::new(RefCell::new(OpeningKey::new(
168            ubk,
169            Sequence::new(nonce_len),
170        ))));
171    }
172
173    fn sealing_encode(&self, buf: &mut BytesMut) {
174        self.sealing_key
175            .as_ref()
176            .unwrap()
177            .borrow_mut()
178            .seal_in_place_append_tag(Aad::empty(), buf)
179            .unwrap();
180    }
181
182    fn opening_decode<'a>(&self, buf: &'a mut BytesMut) -> &'a mut [u8] {
183        self.opening_key
184            .as_ref()
185            .unwrap()
186            .borrow_mut()
187            .open_in_place(Aad::empty(), buf)
188            .unwrap()
189    }
190}
191
192impl Encoder for AeadCodec {
193    type Item = Bytes;
194    type Error = std::io::Error;
195
196    fn encode(&mut self, line: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
197        let mut body_len = line.len();
198        if body_len == 0 {
199            return Ok(());
200        }
201        if body_len > u32::max_value() as usize {
202            return Err(std::io::Error::new(
203                std::io::ErrorKind::InvalidData,
204                "Input line is too long.",
205            ));
206        }
207
208        let padding = self.get_padding();
209        body_len += padding.len();
210
211        let mut head_len = std::mem::size_of::<u32>();
212        let tag_len = self.builder.aead_algorithm.tag_len();
213        head_len += tag_len;
214        body_len += tag_len;
215        if self.sealing_key.is_none() {
216            let salt_len = self.builder.salt_len;
217            buf.reserve(salt_len + head_len + body_len);
218            self.derive_encode_key(buf);
219        } else {
220            buf.reserve(head_len + body_len);
221        }
222
223        let mut buf2 = buf.split_off(buf.len());
224        buf2.put_u32(body_len as u32);
225        self.sealing_encode(&mut buf2);
226        buf.extend_from_slice(&buf2);
227
228        buf2 = buf.split_off(buf.len());
229        buf2.put_slice(&line);
230        buf2.put_slice(&padding);
231        self.sealing_encode(&mut buf2);
232        buf.extend_from_slice(&buf2);
233
234        Ok(())
235    }
236}
237
238impl Decoder for AeadCodec {
239    type Item = Bytes;
240    type Error = std::io::Error;
241
242    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
243        if self.opening_key.is_none() {
244            let salt_len = self.builder.salt_len;
245            if buf.len() >= salt_len {
246                self.derive_decode_key(buf);
247            } else {
248                return Ok(None);
249            }
250        }
251
252        let tag_len = self.builder.aead_algorithm.tag_len();
253        let head_len = std::mem::size_of::<u32>() + tag_len;
254        if self.body_len == None && buf.len() >= head_len {
255            let mut head = buf.split_to(head_len);
256            let head = self.opening_decode(&mut head).to_vec();
257            let mut head = Bytes::from(head);
258            self.body_len = Some(head.get_u32() as usize);
259        }
260
261        if let Some(body_len) = self.body_len {
262            if buf.len() >= body_len {
263                self.body_len = None;
264                let mut body = buf.split_to(body_len);
265                let body = self.opening_decode(&mut body);
266                let padding_len = *body.iter().rev().find(|&&v| v > 0).unwrap();
267                let (body, _) = body.split_at(body.len() - (padding_len as usize));
268                Ok(Some(Bytes::from(body.to_vec())))
269            } else {
270                Ok(None)
271            }
272        } else {
273            Ok(None)
274        }
275    }
276}
277
278// --
279
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);

        let h = b"hello";
        let mut bytes = BytesMut::new();
        println!("cap = {}", bytes.capacity());

        bytes.put_slice(h);
        println!("bytes = {:?}", bytes);

        let e = bytes.iter().rev().find(|&s| *s == 101);
        assert_eq!(*e.unwrap(), 101);
    }

    #[test]
    fn t_get_padding() {
        let builder = Builder::default();
        let codec = builder.create("abcd");
        println!();
        for _ in 0..8 {
            let padding = codec.get_padding();
            // The default builder bounds the padding to 1..=128 bytes.
            assert!((1..=128).contains(&padding.len()));
            assert_eq!(padding[0] as usize, padding.len());
            padding[1..].iter().for_each(|x| assert_eq!(*x, 0));
            println!("padding = {:?}", padding);
        }
    }

    #[test]
    fn t_sequence() {
        let nonce_len = 12_usize;
        let mut seq = Sequence::new(nonce_len);
        println!();
        for i in 1..11_u64 {
            seq.advance().unwrap();
            println!("{:?}", seq.0);
            // Read the low 64-bit counter (bytes 4..12) by copying bytes. A raw-pointer
            // dereference at offset 4 of a `Vec<u8>` would be misaligned.
            let mut low = [0_u8; 8];
            low.copy_from_slice(&seq.0[4..]);
            let low = u64::from_le_bytes(low);
            println!("low = {}", low);
            assert_eq!(i, low);
            assert_eq!(&seq.0[..4], &[0_u8; 4]);
        }
    }

    use bytes::Bytes;
    use futures::{executor, io::Cursor, sink::SinkExt, StreamExt, TryStreamExt};
    use futures_codec::{BytesCodec, Decoder, Encoder, Framed, FramedRead, FramedWrite};
    use std::io::Error;

    use crate::{AeadCodecBuilder as Builder, AES_256_GCM, HKDF_SHA512};

    #[test]
    fn aead_test1() {
        executor::block_on(async move {
            let buf = b"Hello World!";
            let mut framed = FramedRead::new(&buf[..], BytesCodec {});

            let msg = framed.try_next().await.unwrap().unwrap();
            println!("msg: {:?}", msg);
            assert_eq!(msg, Bytes::from(&buf[..]));
        });

        let psk = "thisispsk";

        let mut builder = Builder::default();
        executor::block_on(codec(
            builder.clone().create(psk),
            builder.clone().create(psk),
        ));

        builder
            .set_salt_len(24)
            .set_padding_len(64)
            .set_aead_algorithm(&AES_256_GCM)
            .set_hkdf_algorithm(&HKDF_SHA512);
        executor::block_on(codec2(
            builder.clone().create(psk),
            builder.clone().create(psk),
        ));
    }

    /// Round-trips numbered messages through a `FramedWrite` / `FramedRead` pair.
    async fn codec<T>(c: T, d: T)
    where
        T: Encoder<Item = Bytes, Error = Error> + Decoder<Item = Bytes, Error = Error> + Clone,
    {
        const COUNT: usize = 88;

        let mut buf = vec![];
        let cur = Cursor::new(&mut buf);
        let mut framed = FramedWrite::new(cur, c);

        for i in 1..=COUNT {
            let msg = Bytes::from(format!("Hello World! #{}", i));
            framed.send(msg).await.unwrap();
        }
        println!("buf: {:?}", buf);

        let mut received = 0;
        let mut framed2 = FramedRead::new(&buf[..], d);
        while let Some(msg2) = framed2.next().await {
            let msg2 = msg2.unwrap();
            println!("msg: {:?}", msg2);

            received += 1;
            assert_eq!(msg2, Bytes::from(format!("Hello World! #{}", received)));
        }
        // Without this the test would pass vacuously if the decoder stopped early.
        assert_eq!(received, COUNT);
    }

    /// Round-trips numbered messages through two `Framed` instances over the same buffer.
    async fn codec2<T>(c: T, d: T)
    where
        T: Encoder<Item = Bytes, Error = Error> + Decoder<Item = Bytes, Error = Error> + Clone,
    {
        const COUNT: usize = 68;

        let mut buf = vec![];
        let cur = Cursor::new(&mut buf);
        let mut framed = Framed::new(cur, c);

        for i in 1..=COUNT {
            let msg = Bytes::from(format!("Hello Customer! #{}", i));
            framed.send(msg).await.unwrap();
        }
        println!("buf: {:?}", buf);

        let mut received = 0;
        let cur = Cursor::new(&mut buf);
        let mut framed2 = Framed::new(cur, d);
        while let Some(msg2) = framed2.try_next().await.unwrap() {
            println!("msg: {:?}", msg2);

            received += 1;
            assert_eq!(msg2, Bytes::from(format!("Hello Customer! #{}", received)));
        }
        // Without this the test would pass vacuously if the decoder stopped early.
        assert_eq!(received, COUNT);
    }
}