1use bytes::{BufMut, Bytes, BytesMut};
2
3use futures::ready;
4use ring::aead::{Aad, BoundKey, OpeningKey, SealingKey, UnboundKey, AES_256_GCM};
5
6use core::slice;
7use std::io::{self, ErrorKind};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tracing::trace;
12
13use super::kind::CipherKind;
14use super::util;
15
/// Progress of `EncryptedWriter::poll_write` through a single chunk.
enum EncryptWriteState {
    /// Nothing is in flight; the next call seals a fresh chunk into the buffer.
    AssemblePacket,
    /// A sealed chunk sits in the buffer and is being handed to the stream.
    Writing {
        /// Number of buffered bytes the stream has already accepted.
        pos: usize,
    },
}
20
21pub struct EncryptedWriter {
22 sealing_key: Option<SealingKey<util::NonceSequence>>, buf: BytesMut,
24 state: EncryptWriteState,
25 kind: CipherKind,
26}
27
28impl EncryptedWriter {
29 pub fn new(kind: CipherKind, key: &[u8], salt: &[u8]) -> Self {
30 match kind {
31 CipherKind::AES_256_GCM => {
32 let mut buf = BytesMut::with_capacity(salt.len());
33 buf.put(salt);
34
35 let sub_key = util::hkdf_sha1(key, salt);
37
38 let unbound =
39 UnboundKey::new(&AES_256_GCM, &sub_key).expect("key.len != algorithm.key_len");
40 let sealing_key = SealingKey::new(unbound, util::NonceSequence::new());
41
42 Self {
43 sealing_key: Some(sealing_key),
44 buf,
45 state: EncryptWriteState::AssemblePacket,
46 kind,
47 }
48 }
49 _ => panic!("unsupport chipher kind"),
50 }
51 }
52
53 pub fn poll_write<S>(
62 &mut self,
63 cx: &mut Context,
64 stream: &mut S,
65 mut buf: &[u8],
66 ) -> Poll<io::Result<usize>>
67 where
68 S: AsyncWrite + Unpin + ?Sized,
69 {
70 if buf.len() > self.kind.max_package_size() {
71 buf = &buf[..self.kind.max_package_size()]
72 }
73
74 loop {
75 match self.state {
76 EncryptWriteState::AssemblePacket => {
77 let befor_len = self.buf.len(); let length_size = 2;
80 self.buf.reserve(length_size);
81 self.buf.put_u16(buf.len() as u16);
82 let view = &mut self.buf.as_mut()[befor_len..];
83 debug_assert!(view.len() == length_size);
84 let tag = self
85 .sealing_key
86 .as_mut()
87 .unwrap()
88 .seal_in_place_separate_tag(Aad::<[u8; 0]>::empty(), view)
89 .expect("seal_in_place_separate_tag for length");
90 self.buf.extend_from_slice(tag.as_ref());
91
92 let befor_len = self.buf.len(); self.buf.extend_from_slice(buf);
95 let view = &mut self.buf.as_mut()[befor_len..];
96 let tag = self
97 .sealing_key
98 .as_mut()
99 .unwrap()
100 .seal_in_place_separate_tag(Aad::<[u8; 0]>::empty(), view)
101 .expect("seal_in_place_separate_tag for data");
102 self.buf.extend_from_slice(tag.as_ref());
103
104 self.state = EncryptWriteState::Writing { pos: 0 };
106 }
107 EncryptWriteState::Writing { ref mut pos } => {
108 while *pos < self.buf.len() {
109 let n = ready!(Pin::new(&mut *stream).poll_write(cx, &self.buf[*pos..]))?;
110 *pos += n;
111 }
112
113 self.state = EncryptWriteState::AssemblePacket;
115 self.buf.clear();
116 return Ok(buf.len()).into();
117 }
118 }
119 }
120 }
121}
122
/// Progress of `DecryptedReader::poll_read` through the incoming stream.
enum DecryptReadState {
    /// The peer's salt has not been read yet.
    WaitSalt,
    /// Expecting the next encrypted 2-byte length and its tag.
    ReadLength,
    /// Expecting `length` bytes of encrypted payload plus its tag.
    ReadData {
        /// Plaintext payload length announced by the preceding length field.
        length: usize,
    },
    /// A decrypted payload sits in the buffer and is being handed out.
    BufferedData {
        /// Index of the first buffered byte not yet returned to the caller.
        pos: usize,
    },
}
129pub struct DecryptedReader {
130 opening_key: Option<OpeningKey<util::NonceSequence>>, buf: BytesMut,
132 state: DecryptReadState,
133 kind: CipherKind,
134 salt: Option<Bytes>,
135 key: Bytes,
136}
137
138impl DecryptedReader {
139 pub fn new(kind: CipherKind, key: &[u8]) -> Self {
140 match kind {
141 CipherKind::AES_256_GCM => Self {
142 opening_key: None,
143 buf: BytesMut::new(),
144 state: DecryptReadState::WaitSalt,
145 kind,
146 salt: None,
147 key: Bytes::copy_from_slice(key),
148 },
149 _ => panic!("unsupport chipher kind"),
150 }
151 }
152
153 pub fn poll_read<S>(
154 &mut self,
155 cx: &mut Context,
156 stream: &mut S,
157 buf: &mut ReadBuf,
158 ) -> Poll<io::Result<()>>
159 where
160 S: AsyncRead + Unpin + ?Sized,
161 {
162 loop {
163 match self.state {
164 DecryptReadState::WaitSalt => {
165 let salt_len = self.kind.salt_len();
166 let n = ready!(self.poll_read_exact_or_zero(cx, stream, salt_len))?;
167 if n == 0 {
168 return Err(ErrorKind::UnexpectedEof.into()).into();
169 }
170 debug_assert!(self.buf.len() == salt_len);
171 self.salt = Some(Bytes::copy_from_slice(&self.buf));
172
173 let sub_key = util::hkdf_sha1(&self.key, &self.salt.as_ref().unwrap());
175 trace!("peer sub_key is {:?}", sub_key);
176
177 let unbound = UnboundKey::new(&AES_256_GCM, &sub_key)
178 .expect("key.len != algorithm.key_len");
179 let opening_key = OpeningKey::new(unbound, util::NonceSequence::new());
180
181 self.buf.clear();
182 self.state = DecryptReadState::ReadLength;
183 self.buf.reserve(2 + self.kind.tag_len());
184 self.opening_key = Some(opening_key);
185 }
186 DecryptReadState::ReadLength => {
187 let usize =
188 ready!(self.poll_read_exact_or_zero(cx, stream, 2 + self.kind.tag_len()))?;
189 if usize == 0 {
190 return Ok(()).into();
191 } else {
192 let result = self
193 .opening_key
194 .as_mut()
195 .unwrap()
196 .open_in_place(Aad::<[u8; 0]>::empty(), &mut self.buf)
197 .map_err(|_| {
198 io::Error::new(ErrorKind::Other, "ReadLength invalid tag-in")
199 })?;
200 let plen = u16::from_be_bytes([result[0], result[1]]) as usize;
201 if plen > self.kind.max_package_size() {
202 let err = io::Error::new(
203 ErrorKind::InvalidData,
204 format!(
205 "buffer size too large ({:#x}), AEAD encryption protocol requires buffer to be smaller than 0x3FFF, the higher two bits must be set to zero",
206 plen
207 ),
208 );
209 return Err(err).into();
210 }
211 self.buf.clear();
212 self.state = DecryptReadState::ReadData { length: plen };
213 self.buf.reserve(plen + self.kind.tag_len())
214 }
215 }
216 DecryptReadState::ReadData { length } => {
217 let data_len = length + self.kind.tag_len();
218 let n = ready!(self.poll_read_exact_or_zero(cx, stream, data_len))?;
219 if n == 0 {
220 return Err(ErrorKind::UnexpectedEof.into()).into();
221 }
222 debug_assert_eq!(data_len, self.buf.len());
223
224 let _ = self
225 .opening_key
226 .as_mut()
227 .unwrap()
228 .open_in_place(Aad::<[u8; 0]>::empty(), &mut self.buf)
229 .map_err(|_| io::Error::new(ErrorKind::Other, "ReadData invalid tag-in"))?;
230
231 self.buf.truncate(length);
233 self.state = DecryptReadState::BufferedData { pos: 0 };
234 }
235 DecryptReadState::BufferedData { ref mut pos } => {
236 if *pos < self.buf.len() {
237 let buffered = &self.buf[*pos..];
238 let consumed = usize::min(buffered.len(), buf.remaining());
239 buf.put_slice(&buffered[..consumed]);
240 *pos += consumed;
241
242 return Ok(()).into();
243 }
244 self.buf.clear();
245 self.state = DecryptReadState::ReadLength;
246 self.buf.reserve(2 + self.kind.tag_len());
247 }
248 }
249 }
250 }
251
252 fn poll_read_exact_or_zero<S>(
253 &mut self,
254 cx: &mut Context,
255 stream: &mut S,
256 size: usize,
257 ) -> Poll<io::Result<usize>>
258 where
259 S: AsyncRead + Unpin + ?Sized,
260 {
261 assert!(size != 0);
262 while self.buf.len() < size {
263 let remaing = size - self.buf.len();
264
265 let view = &mut self.buf.chunk_mut()[..remaing];
266 debug_assert_eq!(view.len(), remaing);
267 let mut read_buf = ReadBuf::uninit(unsafe {
268 slice::from_raw_parts_mut(view.as_mut_ptr() as *mut _, remaing)
269 });
270
271 ready!(Pin::new(&mut *stream).poll_read(cx, &mut read_buf))?;
272 let n = read_buf.filled().len();
273
274 unsafe { self.buf.advance_mut(n) }
275
276 if n == 0 {
277 if !self.buf.is_empty() {
278 return Err(ErrorKind::UnexpectedEof.into()).into();
279 } else {
280 return Ok(0).into();
281 }
282 }
283 }
284 Ok(size).into()
285 }
286}
287
#[cfg(test)]
mod tests {
    use std::io;
    use std::task::{Context, Poll};

    use futures::future::poll_fn;
    use tokio::io::ReadBuf;

    use crate::crypto::{aead::DecryptedReader, kind::CipherKind, util};

    use super::EncryptedWriter;

    /// Unwraps a poll on an in-memory stream, which must never be `Pending`.
    fn expect_ready<T>(poll: Poll<io::Result<T>>) -> T {
        match poll {
            Poll::Ready(result) => result.expect("in-memory I/O failed"),
            Poll::Pending => panic!("in-memory I/O returned Pending"),
        }
    }

    /// Encrypts all of `payload`, calling `poll_write` until every byte is accepted.
    fn encrypt_all(
        cx: &mut Context<'_>,
        writer: &mut EncryptedWriter,
        payload: &[u8],
    ) -> Vec<u8> {
        let mut wire = Vec::new();
        let mut written = 0;
        while written < payload.len() {
            let n = expect_ready(writer.poll_write(cx, &mut wire, &payload[written..]));
            assert!(n > 0, "poll_write made no progress");
            written += n;
        }
        wire
    }

    /// Decrypts `wire` to EOF, handing `poll_read` a `chunk`-byte buffer each call.
    fn decrypt_all(
        cx: &mut Context<'_>,
        reader: &mut DecryptedReader,
        wire: &[u8],
        chunk: usize,
    ) -> Vec<u8> {
        assert!(chunk > 0);
        let mut src = wire;
        let mut scratch = vec![0u8; chunk];
        let mut plain = Vec::new();
        loop {
            let mut read_buf = ReadBuf::new(&mut scratch);
            expect_ready(reader.poll_read(cx, &mut src, &mut read_buf));
            if read_buf.filled().is_empty() {
                return plain;
            }
            plain.extend_from_slice(read_buf.filled());
        }
    }

    #[tokio::test]
    async fn test_reader_writer() {
        let pwd = "123456";
        let salt = &[0u8; 32];
        let key = util::evp_bytes_to_key(pwd.as_bytes(), CipherKind::AES_256_GCM.key_len());
        let content = "hello".as_bytes();

        let mut writer = EncryptedWriter::new(CipherKind::AES_256_GCM, &key, salt);
        let mut reader = DecryptedReader::new(CipherKind::AES_256_GCM, &key);

        poll_fn(|cx| {
            let wire = encrypt_all(cx, &mut writer, content);
            assert_eq!(decrypt_all(cx, &mut reader, &wire, 1024), content);
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn test_reader_writer_multiple_chunks() {
        let key = util::evp_bytes_to_key("123456".as_bytes(), CipherKind::AES_256_GCM.key_len());
        let salt = [7u8; 32];
        let max = CipherKind::AES_256_GCM.max_package_size();
        let tag_len = CipherKind::AES_256_GCM.tag_len();
        // Two full-size chunks followed by a short tail chunk.
        let payload: Vec<u8> = (0..2 * max + 17).map(|i| (i % 251) as u8).collect();

        let mut writer = EncryptedWriter::new(CipherKind::AES_256_GCM, &key, &salt);
        let mut reader = DecryptedReader::new(CipherKind::AES_256_GCM, &key);

        poll_fn(|cx| {
            let wire = encrypt_all(cx, &mut writer, &payload);
            // salt + 3 x (2-byte length + length tag + payload tag) + payload
            assert_eq!(wire.len(), salt.len() + 3 * (2 + 2 * tag_len) + payload.len());
            // A 1000-byte read buffer makes poll_read hand out each chunk in pieces.
            assert_eq!(decrypt_all(cx, &mut reader, &wire, 1000), payload);
            Poll::Ready(())
        })
        .await;
    }
}
346}