servant_codec/aead_codec/
aead.rs1use {
4 bytes::{Buf, BufMut, Bytes, BytesMut},
5 futures_codec::{Decoder, Encoder},
6 ring::{
7 aead::{
8 Aad, Algorithm as AeadAlgorithm, BoundKey, Nonce, NonceSequence, OpeningKey,
9 SealingKey, UnboundKey, CHACHA20_POLY1305,
10 },
11 error::Unspecified,
12 hkdf::{Algorithm as HkdfAlgorithm, Salt, HKDF_SHA256},
13 rand::*,
14 },
15 std::{cell::RefCell, rc::Rc},
16};
17
/// HKDF `info` input handed to `expand` when the sealing and opening keys are derived.
///
/// The encoder and the decoder both feed this constant into the derivation, so two peers
/// only end up with the same key when they agree on this value.
const APP_INFO: &[u8] = b"bee lib";
21
/// Nonce state for one AEAD key: the raw nonce bytes, updated in place by `advance`.
struct Sequence(Vec<u8>);

impl Sequence {
    /// Creates a sequence whose buffer consists of `nonce_len` zero bytes.
    fn new(nonce_len: usize) -> Self {
        let mut bytes = Vec::with_capacity(nonce_len);
        bytes.resize(nonce_len, 0_u8);
        Self(bytes)
    }
}
31
32impl NonceSequence for Sequence {
33 fn advance(&mut self) -> Result<Nonce, Unspecified> {
34 assert_eq!(
35 std::mem::size_of::<u32>() + std::mem::size_of::<u64>(),
36 self.0.len()
37 );
38
39 let ptr = self.0.as_mut_ptr();
40 let p_u32 = ptr as *mut u32;
41 let p_u64: *mut u64;
42 unsafe {
43 p_u64 = ptr.offset(std::mem::size_of::<u32>() as isize) as *mut u64;
44 *p_u64 += 1;
45 if *p_u64 == u64::max_value() {
46 *p_u64 = 0;
47
48 *p_u32 += 1;
49 if *p_u32 == u32::max_value() {
50 *p_u32 = 0;
51 }
52 }
53 }
54 Nonce::try_assume_unique_for_key(&self.0)
55 }
56}
57
/// Configuration from which an `AeadCodec` is created.
#[derive(Clone)]
pub struct Builder {
    /// HKDF algorithm used for the salted extract step of the key derivation.
    hkdf_algorithm: &'static HkdfAlgorithm,
    /// AEAD algorithm the derived key is bound to; it also supplies the nonce and tag lengths.
    aead_algorithm: &'static AeadAlgorithm,
    /// Number of salt bytes the encoder writes ahead of its first frame and the decoder reads
    /// back before deriving its key.
    salt_len: usize,
    /// Upper bound for the random per-frame padding length.
    padding_len: u8,
}
67
68impl Default for Builder {
69 fn default() -> Self {
70 Self {
71 hkdf_algorithm: &HKDF_SHA256,
72 aead_algorithm: &CHACHA20_POLY1305,
73 salt_len: 12,
74 padding_len: 128,
75 }
76 }
77}
78
79impl Builder {
80 pub fn set_hkdf_algorithm(&mut self, a: &'static HkdfAlgorithm) -> &mut Self {
81 self.hkdf_algorithm = a;
82 self
83 }
84 pub fn set_aead_algorithm(&mut self, a: &'static AeadAlgorithm) -> &mut Self {
85 self.aead_algorithm = a;
86 self
87 }
88 pub fn set_salt_len(&mut self, len: usize) -> &mut Self {
89 self.salt_len = len;
90 self
91 }
92 pub fn set_padding_len(&mut self, len: u8) -> &mut Self {
93 self.padding_len = len;
94 self
95 }
96 pub fn create(self, psk: &str) -> AeadCodec {
97 AeadCodec::new(self, psk)
98 }
99}
100
/// Framing codec that seals outgoing items and opens incoming frames with an AEAD key
/// derived from a pre-shared key.
///
/// Cloning copies the `Rc` handles, so a clone refers to the same sealing/opening key
/// objects (and therefore the same nonce sequences) once they have been derived.
#[derive(Clone)]
pub struct AeadCodec {
    /// Configuration the codec was created from.
    builder: Rc<Builder>,
    /// Pre-shared key used as the HKDF input keying material.
    psk: Rc<String>,
    /// Key for sealing; `None` until the first non-empty item is encoded.
    sealing_key: Option<Rc<RefCell<SealingKey<Sequence>>>>,
    /// Key for opening; `None` until the salt has been read from the input.
    opening_key: Option<Rc<RefCell<OpeningKey<Sequence>>>>,
    /// Length of the sealed body announced by the last decoded header; `None` while the
    /// decoder is waiting for a header.
    body_len: Option<usize>,
}
111
112impl AeadCodec {
113 fn new(builder: Builder, psk: &str) -> Self {
114 Self {
115 builder: Rc::new(builder),
116 psk: Rc::new(String::from(psk)),
117 sealing_key: None,
118 opening_key: None,
119 body_len: None,
120 }
121 }
122
123 fn get_padding(&self) -> Vec<u8> {
124 let rng = SystemRandom::new();
125 let mut padding_len = crate::utility::no_zero_rand_gen::<u8>(&rng);
126 padding_len %= self.builder.padding_len;
127 padding_len += 1;
128
129 let mut v = Vec::<u8>::new();
130 v.push(padding_len);
131 v.resize(padding_len as usize, 0);
132 v
133 }
134
135 fn derive_encode_key(&mut self, buf: &mut BytesMut) {
136 assert!(self.sealing_key.is_none());
137
138 let mut salt = vec![0_u8; self.builder.salt_len];
139 let rng = SystemRandom::new();
140 rng.fill(&mut salt).unwrap();
141 buf.put_slice(&salt);
142
143 let salt = Salt::new(*self.builder.hkdf_algorithm, &salt);
144 let prk = salt.extract(self.psk.as_bytes());
145 let okm = prk
146 .expand(&[APP_INFO], self.builder.aead_algorithm)
147 .unwrap();
148 let ubk = UnboundKey::from(okm);
149 let nonce_len = self.builder.aead_algorithm.nonce_len();
150 self.sealing_key = Some(Rc::new(RefCell::new(SealingKey::new(
151 ubk,
152 Sequence::new(nonce_len),
153 ))));
154 }
155
156 fn derive_decode_key(&mut self, buf: &mut BytesMut) {
157 assert!(self.opening_key.is_none());
158
159 let salt = buf.split_to(self.builder.salt_len);
160 let salt = Salt::new(*self.builder.hkdf_algorithm, &salt);
161 let prk = salt.extract(self.psk.as_bytes());
162 let okm = prk
163 .expand(&[APP_INFO], self.builder.aead_algorithm)
164 .unwrap();
165 let ubk = UnboundKey::from(okm);
166 let nonce_len = self.builder.aead_algorithm.nonce_len();
167 self.opening_key = Some(Rc::new(RefCell::new(OpeningKey::new(
168 ubk,
169 Sequence::new(nonce_len),
170 ))));
171 }
172
173 fn sealing_encode(&self, buf: &mut BytesMut) {
174 self.sealing_key
175 .as_ref()
176 .unwrap()
177 .borrow_mut()
178 .seal_in_place_append_tag(Aad::empty(), buf)
179 .unwrap();
180 }
181
182 fn opening_decode<'a>(&self, buf: &'a mut BytesMut) -> &'a mut [u8] {
183 self.opening_key
184 .as_ref()
185 .unwrap()
186 .borrow_mut()
187 .open_in_place(Aad::empty(), buf)
188 .unwrap()
189 }
190}
191
192impl Encoder for AeadCodec {
193 type Item = Bytes;
194 type Error = std::io::Error;
195
196 fn encode(&mut self, line: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
197 let mut body_len = line.len();
198 if body_len == 0 {
199 return Ok(());
200 }
201 if body_len > u32::max_value() as usize {
202 return Err(std::io::Error::new(
203 std::io::ErrorKind::InvalidData,
204 "Input line is too long.",
205 ));
206 }
207
208 let padding = self.get_padding();
209 body_len += padding.len();
210
211 let mut head_len = std::mem::size_of::<u32>();
212 let tag_len = self.builder.aead_algorithm.tag_len();
213 head_len += tag_len;
214 body_len += tag_len;
215 if self.sealing_key.is_none() {
216 let salt_len = self.builder.salt_len;
217 buf.reserve(salt_len + head_len + body_len);
218 self.derive_encode_key(buf);
219 } else {
220 buf.reserve(head_len + body_len);
221 }
222
223 let mut buf2 = buf.split_off(buf.len());
224 buf2.put_u32(body_len as u32);
225 self.sealing_encode(&mut buf2);
226 buf.extend_from_slice(&buf2);
227
228 buf2 = buf.split_off(buf.len());
229 buf2.put_slice(&line);
230 buf2.put_slice(&padding);
231 self.sealing_encode(&mut buf2);
232 buf.extend_from_slice(&buf2);
233
234 Ok(())
235 }
236}
237
238impl Decoder for AeadCodec {
239 type Item = Bytes;
240 type Error = std::io::Error;
241
242 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
243 if self.opening_key.is_none() {
244 let salt_len = self.builder.salt_len;
245 if buf.len() >= salt_len {
246 self.derive_decode_key(buf);
247 } else {
248 return Ok(None);
249 }
250 }
251
252 let tag_len = self.builder.aead_algorithm.tag_len();
253 let head_len = std::mem::size_of::<u32>() + tag_len;
254 if self.body_len == None && buf.len() >= head_len {
255 let mut head = buf.split_to(head_len);
256 let head = self.opening_decode(&mut head).to_vec();
257 let mut head = Bytes::from(head);
258 self.body_len = Some(head.get_u32() as usize);
259 }
260
261 if let Some(body_len) = self.body_len {
262 if buf.len() >= body_len {
263 self.body_len = None;
264 let mut body = buf.split_to(body_len);
265 let body = self.opening_decode(&mut body);
266 let padding_len = *body.iter().rev().find(|&&v| v > 0).unwrap();
267 let (body, _) = body.split_at(body.len() - (padding_len as usize));
268 Ok(Some(Bytes::from(body.to_vec())))
269 } else {
270 Ok(None)
271 }
272 } else {
273 Ok(None)
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures::{executor, io::Cursor, sink::SinkExt, StreamExt, TryStreamExt};
    use futures_codec::{BytesCodec, Decoder, Encoder, Framed, FramedRead, FramedWrite};
    use std::io::Error;

    use crate::{AeadCodecBuilder as Builder, AES_256_GCM, HKDF_SHA512};

    // Smoke test for `BytesMut` and reverse iteration over its bytes.
    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);

        let h = b"hello";
        let mut bytes = BytesMut::new();
        println!("cap = {}", bytes.capacity());

        bytes.put_slice(h);
        println!("bytes = {:?}", bytes);

        let e = bytes.iter().rev().find(|&s| *s == 101);
        assert_eq!(*e.unwrap(), 101);
    }

    // Padding must start with its own length and be zero everywhere else.
    #[test]
    fn t_get_padding() {
        let builder = Builder::default();
        let codec = builder.create("abcd");
        println!();
        for _ in 0..8 {
            let padding = codec.get_padding();
            assert_eq!(padding[0] as usize, padding.len());
            padding[1..].iter().for_each(|x| assert_eq!(*x, 0));
            println!("padding = {:?}", padding);
        }
    }

    // The low counter (bytes 4..12, native byte order) must count 1, 2, 3, ...
    #[test]
    fn t_sequence() {
        let nonce_len = 12_usize;
        let mut seq = Sequence::new(nonce_len);
        println!();
        for i in 1..11_u64 {
            seq.advance().unwrap();
            println!("{:?}", seq.0);

            // Read the counter through a byte array: the buffer offers no alignment
            // guarantee, so dereferencing a casted pointer into it is undefined behaviour.
            let mut low = [0_u8; std::mem::size_of::<u64>()];
            low.copy_from_slice(&seq.0[std::mem::size_of::<u32>()..]);
            let counter = u64::from_ne_bytes(low);
            println!("*ptr = {}", counter);
            assert_eq!(i, counter);
        }
    }

    // Round-trips frames through the codec with the default and a customised configuration.
    #[test]
    fn aead_test1() {
        executor::block_on(async move {
            let buf = b"Hello World!";
            let mut framed = FramedRead::new(&buf[..], BytesCodec {});

            let msg = framed.try_next().await.unwrap().unwrap();
            println!("msg: {:?}", msg);
            assert_eq!(msg, Bytes::from(&buf[..]));
        });

        let psk = "thisispsk";

        let mut builder = Builder::default();
        executor::block_on(codec(
            builder.clone().create(psk),
            builder.clone().create(psk),
        ));

        builder
            .set_salt_len(24)
            .set_padding_len(64)
            .set_aead_algorithm(&AES_256_GCM)
            .set_hkdf_algorithm(&HKDF_SHA512);
        executor::block_on(codec2(
            builder.clone().create(psk),
            builder.clone().create(psk),
        ));
    }

    // Writes 88 messages with `FramedWrite` and reads them back with `FramedRead`.
    async fn codec<T>(c: T, d: T)
    where
        T: Encoder<Item = Bytes, Error = Error> + Decoder<Item = Bytes, Error = Error> + Clone,
    {
        const MESSAGES: usize = 88;

        let mut buf = vec![];
        let cur = Cursor::new(&mut buf);
        let mut framed = FramedWrite::new(cur, c);

        for i in 1..=MESSAGES {
            let msg = Bytes::from(format!("Hello World! #{}", i));
            framed.send(msg).await.unwrap();
        }
        println!("buf: {:?}", buf);

        let mut received = 0_usize;
        let mut framed2 = FramedRead::new(&buf[..], d);
        while let Some(msg2) = framed2.next().await {
            let msg2 = msg2.unwrap();
            println!("msg: {:?}", msg2);

            received += 1;
            assert_eq!(msg2, Bytes::from(format!("Hello World! #{}", received)));
        }
        // Without this check the test would also pass if nothing was decoded at all.
        assert_eq!(received, MESSAGES);
    }

    // Writes 68 messages and reads them back, using `Framed` for both directions.
    async fn codec2<T>(c: T, d: T)
    where
        T: Encoder<Item = Bytes, Error = Error> + Decoder<Item = Bytes, Error = Error> + Clone,
    {
        const MESSAGES: usize = 68;

        let mut buf = vec![];
        let cur = Cursor::new(&mut buf);
        let mut framed = Framed::new(cur, c);

        for i in 1..=MESSAGES {
            let msg = Bytes::from(format!("Hello Customer! #{}", i));
            framed.send(msg).await.unwrap();
        }
        println!("buf: {:?}", buf);

        let mut received = 0_usize;
        let cur = Cursor::new(&mut buf);
        let mut framed2 = Framed::new(cur, d);
        while let Some(msg2) = framed2.try_next().await.unwrap() {
            println!("msg: {:?}", msg2);

            received += 1;
            assert_eq!(msg2, Bytes::from(format!("Hello Customer! #{}", received)));
        }
        // Without this check the test would also pass if nothing was decoded at all.
        assert_eq!(received, MESSAGES);
    }
}