1use futures_core::Stream;
2use bytes::{Bytes, BytesMut, BufMut};
3use ctr::cipher::{KeyIvInit, StreamCipher};
4use hmac::{Mac, KeyInit};
5use constant_time_eq::{constant_time_eq_32, constant_time_eq_64};
6use core::pin::Pin;
7use core::fmt::Display;
8use core::error::Error;
9use core::task::{Context, Poll, ready};
10use crate::util::{HmacSha3_256, kdf, compute_verification_hash};
11use crate::{BYTES_PER_POLL, Aes256Ctr};
12
13pin_project_lite::pin_project! {
14 pub struct Decrypt<R> {
15 #[pin]
16 read: Option<R>,
17 buffer: Option<BytesMut>
18 }
19}
20
21impl<R> Decrypt<R> {
22 pub fn new(read: R) -> Self {
23 Self {
24 read: Some(read),
25 buffer: Some(BytesMut::new())
26 }
27 }
28}
29
/// Errors that can occur while reading and validating an SSEC file header.
#[derive(Debug)]
pub enum SsecHeaderError<E> {
    /// The data is not an SSEC file: the magic bytes or compression byte were not recognised,
    /// or the stream ended before a full header was read.
    NotSsec,
    /// The header's version byte is not one this implementation supports; holds that byte.
    UnsupportedVersion(u8),
    /// The header names a recognised compression algorithm that is not supported yet; holds
    /// that byte.
    UnsupportedCompression(u8),
    /// The wrapped stream yielded an error.
    Stream(E),
}

impl<E: Display> Display for SsecHeaderError<E> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Stream errors are shown exactly as the wrapped stream reported them.
            Self::Stream(inner) => Display::fmt(inner, f),
            Self::UnsupportedCompression(c) => write!(f, "SSEC compression algorithm {c:?} is valid but currently unsupported"),
            Self::UnsupportedVersion(v) => write!(f, "SSEC file version {v:?} is unsupported"),
            Self::NotSsec => f.write_str("wrapped stream did not produce a SSEC file"),
        }
    }
}

impl<E> Error for SsecHeaderError<E>
where
    E: Error + 'static,
    Self: Display,
{
    /// Only [`SsecHeaderError::Stream`] has an underlying cause: the wrapped stream's error.
    #[inline]
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Stream(inner) => Some(inner),
            Self::NotSsec | Self::UnsupportedVersion(_) | Self::UnsupportedCompression(_) => None,
        }
    }
}
64
/// Total size in bytes of an SSEC header: the `SSEC` magic (4), the version byte (1), the
/// compression byte (1), the salt (32), the key verification hash (64) and the IV (16).
const HEADER_LEN: usize = 4 + 1 + 1 + 32 + 64 + 16;
66
67impl<E, R: Stream<Item = Result<Bytes, E>> + Unpin> Future for Decrypt<R> {
68 type Output = Result<Box<DecryptAwaitingPassword<R>>, SsecHeaderError<E>>;
69
70 fn poll(
71 self: Pin<&mut Self>,
72 cx: &mut Context<'_>
73 ) -> Poll<Self::Output> {
74 let this = self.project();
75
76 if this.buffer.as_ref().unwrap().len() >= HEADER_LEN {
77 let read = this.read.get_mut().take().unwrap();
78 let mut buffer = this.buffer.take().unwrap();
79
80 let header = buffer.split_to(HEADER_LEN);
81
82 if &header[0..=3] != b"SSEC" {
83 return Poll::Ready(Err(SsecHeaderError::NotSsec));
84 }
85
86 if header[4] != 0x01 {
87 return Poll::Ready(Err(SsecHeaderError::UnsupportedVersion(header[4])));
88 }
89
90 if header[5] != 0x6e {
91 return Poll::Ready(Err(match header[5] {
92 0x62 => SsecHeaderError::UnsupportedCompression(0x62),
93 _ => SsecHeaderError::NotSsec
94 }));
95 }
96
97 let salt: [u8; 32] = header[6..=37].try_into().unwrap();
98 let verification_hash: [u8; 64] = header[38..=101].try_into().unwrap();
99 let iv: [u8; 16] = header[102..HEADER_LEN].try_into().unwrap();
100
101 Poll::Ready(Ok(Box::new(DecryptAwaitingPassword {
102 read,
103 buffer,
104 salt,
105 verification_hash,
106 iv,
107 version_byte: header[4],
108 compression_algo: header[5]
109 })))
110 } else {
111 let read = this.read.as_pin_mut().unwrap();
112
113 match read.poll_next(cx) {
114 Poll::Ready(Some(Ok(bytes))) => {
115 this.buffer.as_mut().unwrap().put(bytes);
116 cx.waker().wake_by_ref();
117 Poll::Pending
118 },
119 Poll::Ready(Some(Err(e))) => Poll::Ready(Err(SsecHeaderError::Stream(e))),
120 Poll::Ready(None) => Poll::Ready(Err(SsecHeaderError::NotSsec)),
121 Poll::Pending => Poll::Pending
122 }
123 }
124 }
125}
126
127pub struct DecryptAwaitingPassword<R> {
128 read: R,
129 buffer: BytesMut,
130 salt: [u8; 32],
131 verification_hash: [u8; 64],
132 iv: [u8; 16],
133 version_byte: u8,
134 compression_algo: u8
135}
136
/// Length in bytes of the HMAC-SHA3-256 integrity code stored at the very end of the file
/// (256 bits).
const HMAC_LEN: usize = 256 / 8;
138
139impl<R> DecryptAwaitingPassword<R> {
140 pub fn try_password(mut self: Box<Self>, password: &[u8]) -> Result<DecryptStream<R>, Box<Self>> {
147 let key = kdf(password, &self.salt);
148
149 if constant_time_eq_64(compute_verification_hash(&key).as_ref(), &self.verification_hash) {
150 let mut integrity_code = HmacSha3_256::new_from_slice(key.as_ref().get_ref()).unwrap();
151 integrity_code.update(&[self.version_byte, self.compression_algo]);
152 integrity_code.update(&self.iv);
153
154 let buf_len = self.buffer.len();
155 let eof_buf = if buf_len >= HMAC_LEN {
156 self.buffer.split_off(buf_len - HMAC_LEN)
157 } else {
158 self.buffer.split()
159 };
160 debug_assert!(eof_buf.len() <= HMAC_LEN);
161
162 let state = DecryptState::PostHeader(Box::new(DecryptionState {
163 aes: Aes256Ctr::new(key.as_ref().get_ref().into(), (&self.iv).into()),
164 integrity_code: Some(integrity_code),
165 eof: false,
166 eof_buf
167 }));
168
169 Ok(DecryptStream {
170 read: self.read,
171 state,
172 buffer: self.buffer
173 })
174 } else {
175 Err(self)
176 }
177 }
178}
179
180struct DecryptionState {
181 aes: Aes256Ctr,
182 integrity_code: Option<HmacSha3_256>,
183 eof: bool,
184 eof_buf: BytesMut
185}
186
187enum DecryptState {
188 PostHeader(Box<DecryptionState>),
189 Done
190}
191
192pin_project_lite::pin_project! {
193 pub struct DecryptStream<R> {
194 #[pin]
195 read: R,
196 state: DecryptState,
197 buffer: BytesMut,
198 }
199}
200
/// Errors that can occur while decrypting the body of an SSEC file.
#[derive(Debug)]
pub enum DecryptStreamError<E> {
    /// The stream ended with fewer than `HMAC_LEN` bytes after the header, so there was no
    /// complete integrity code.
    TooShort,
    /// The integrity code did not match the data. Everything already yielded is inauthentic.
    IntegrityFailed,
    /// The wrapped stream yielded an error.
    Stream(E),
}

impl<E: Display> Display for DecryptStreamError<E> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            // Stream errors are shown exactly as the wrapped stream reported them.
            Self::Stream(inner) => return Display::fmt(inner, f),
            Self::IntegrityFailed => "the file has been tampered with, previously decrypted data is inauthentic and should be discarded",
            Self::TooShort => "wrapped stream was too short to have been a valid SSEC file (no integrity code)",
        };
        f.write_str(message)
    }
}

impl<E> Error for DecryptStreamError<E>
where
    E: Error + 'static,
    Self: Display,
{
    /// Only [`DecryptStreamError::Stream`] has an underlying cause: the wrapped stream's error.
    #[inline]
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Stream(inner) => Some(inner),
            Self::TooShort | Self::IntegrityFailed => None,
        }
    }
}
232
233impl<E, R> Stream for DecryptStream<R>
234where
235 R: Stream<Item = Result<Bytes, E>>
236{
237 type Item = Result<Bytes, DecryptStreamError<E>>;
238
239 fn poll_next(
240 self: Pin<&mut Self>,
241 cx: &mut Context<'_>
242 ) -> Poll<Option<Self::Item>> {
243 let this = self.project();
244
245 match this.state {
246 DecryptState::PostHeader(state) => {
247 if state.eof && this.buffer.len() <= BYTES_PER_POLL {
248 if state.eof_buf.len() < HMAC_LEN {
249 *this.state = DecryptState::Done;
250 return Poll::Ready(Some(Err(DecryptStreamError::TooShort)));
251 }
252 debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
253
254 let mut hmac = state.integrity_code.take().expect("integrity_code only taken here");
255 let mut data = this.buffer.split();
256
257 hmac.update(&data);
258 state.aes.apply_keystream(&mut data);
259
260 let stored_integrity_code: &[u8; HMAC_LEN] = state.eof_buf.as_ref().try_into().unwrap();
261 if !constant_time_eq_32(stored_integrity_code, hmac.finalize().into_bytes().as_ref()) {
262 *this.state = DecryptState::Done;
263 return Poll::Ready(Some(Err(DecryptStreamError::IntegrityFailed)));
264 }
265
266 *this.state = DecryptState::Done;
267
268 Poll::Ready(Some(Ok(data.freeze())))
269 } else if this.buffer.len() >= BYTES_PER_POLL {
270 let mut data = this.buffer.split_to(BYTES_PER_POLL);
271
272 state.integrity_code.as_mut().unwrap().update(&data);
273 state.aes.apply_keystream(&mut data);
274
275 Poll::Ready(Some(Ok(data.freeze())))
276 } else {
277 match ready!(this.read.poll_next(cx)) {
278 Some(Ok(bytes)) => {
279 state.eof_buf.put(bytes);
280 let eof_len = state.eof_buf.len();
281 if eof_len > HMAC_LEN {
282 this.buffer.put(state.eof_buf.split_to(eof_len - HMAC_LEN));
283 debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
284 }
285 cx.waker().wake_by_ref();
286 Poll::Pending
287 },
288 Some(Err(e)) => {
289 *this.state = DecryptState::Done;
290 Poll::Ready(Some(Err(DecryptStreamError::Stream(e))))
291 },
292 None => {
293 debug_assert!(state.eof_buf.len() <= HMAC_LEN);
294 state.eof = true;
295 cx.waker().wake_by_ref();
296 Poll::Pending
297 }
298 }
299 }
300 },
301 DecryptState::Done => Poll::Ready(None)
302 }
303 }
304}