1use futures_core::Stream;
2use bytes::{Bytes, BytesMut, BufMut};
3use ctr::cipher::{KeyIvInit, StreamCipher};
4use hmac::{Mac, KeyInit};
5use constant_time_eq::{constant_time_eq_32, constant_time_eq_64};
6use core::pin::Pin;
7use core::fmt::Display;
8use core::error::Error;
9use core::task::{Context, Poll, ready};
10use core::num::NonZeroUsize;
11use crate::util::{HmacSha3_256, kdf, compute_verification_hash};
12use crate::{DEFAULT_BYTES_PER_POLL, Aes256Ctr};
13
/// Optional tuning knobs handed to [`Decrypt::new`].
///
/// The field is private; obtain a value through [`Default`] and adjust it
/// with the setter methods on this type.
#[derive(Clone, Copy, Debug)]
pub struct DecryptArgs {
    // Copied into `Decrypt` by `Decrypt::new` and later used by
    // `DecryptStream` as the size of each chunk it hands out.
    bytes_per_poll: NonZeroUsize,
}
19
20impl Default for DecryptArgs {
21 fn default() -> Self {
23 Self {
24 bytes_per_poll: DEFAULT_BYTES_PER_POLL
25 }
26 }
27}
28
29impl DecryptArgs {
30 pub fn set_bytes_per_poll(&mut self, bytes_per_poll: NonZeroUsize) {
32 self.bytes_per_poll = bytes_per_poll;
33 }
34}
35
36pin_project_lite::pin_project! {
37 pub struct Decrypt<R> {
38 #[pin]
39 read: Option<R>,
40 buffer: Option<BytesMut>,
41 bytes_per_poll: NonZeroUsize
42 }
43}
44
45impl<R> Decrypt<R> {
46 pub fn new(additional_args: DecryptArgs, read: R) -> Self {
47 Self {
48 read: Some(read),
49 buffer: Some(BytesMut::new()),
50 bytes_per_poll: additional_args.bytes_per_poll
51 }
52 }
53}
54
/// Reasons the header stage ([`Decrypt`]) can fail.
#[derive(Debug)]
pub enum SsecHeaderError<E> {
    /// The input is not recognised as an SSEC file: the magic bytes are wrong,
    /// the compression byte is unknown, or the stream ended before a complete
    /// header was received.
    NotSsec,
    /// The header's version byte (carried here) is not one this code accepts.
    UnsupportedVersion(u8),
    /// The header's compression byte (carried here) is recognised but not
    /// implemented.
    UnsupportedCompression(u8),
    /// The wrapped stream itself reported an error.
    Stream(E),
}
62
63impl<E: Display> Display for SsecHeaderError<E> {
64 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65 match self {
66 Self::NotSsec => write!(f, "wrapped stream did not produce a SSEC file"),
67 Self::UnsupportedVersion(v) => write!(f, "SSEC file version {v:?} is unsupported"),
68 Self::UnsupportedCompression(c) => write!(f, "SSEC compression algorithm {c:?} is valid but currently unsupported"),
69 Self::Stream(e) => e.fmt(f)
70 }
71 }
72}
73
74impl<E> Error for SsecHeaderError<E>
75where
76 E: Error + 'static,
77 Self: Display
78{
79 #[inline]
80 fn source(&self) -> Option<&(dyn Error + 'static)> {
81 match self {
82 Self::NotSsec => None,
83 Self::UnsupportedVersion(_) => None,
84 Self::UnsupportedCompression(_) => None,
85 Self::Stream(e) => Some(e)
86 }
87 }
88}
89
/// Total size of the fixed header in bytes, written as the sum of the
/// fields sliced out of it by the header parser: magic (4), version (1),
/// compression (1), salt (32), verification hash (64) and IV (16).
const HEADER_LEN: usize = 4 + 1 + 1 + 32 + 64 + 16;
91
92impl<E, R: Stream<Item = Result<Bytes, E>> + Unpin> Future for Decrypt<R> {
93 type Output = Result<Box<DecryptAwaitingPassword<R>>, SsecHeaderError<E>>;
94
95 fn poll(
96 self: Pin<&mut Self>,
97 cx: &mut Context<'_>
98 ) -> Poll<Self::Output> {
99 let this = self.project();
100
101 if this.buffer.as_ref().unwrap().len() >= HEADER_LEN {
102 let read = this.read.get_mut().take().unwrap();
103 let mut buffer = this.buffer.take().unwrap();
104
105 let header = buffer.split_to(HEADER_LEN);
106
107 if &header[0..=3] != b"SSEC" {
108 return Poll::Ready(Err(SsecHeaderError::NotSsec));
109 }
110
111 if header[4] != 0x01 {
112 return Poll::Ready(Err(SsecHeaderError::UnsupportedVersion(header[4])));
113 }
114
115 if header[5] != 0x6e {
116 return Poll::Ready(Err(match header[5] {
117 0x62 => SsecHeaderError::UnsupportedCompression(0x62),
118 _ => SsecHeaderError::NotSsec
119 }));
120 }
121
122 let salt: [u8; 32] = header[6..=37].try_into().unwrap();
123 let verification_hash: [u8; 64] = header[38..=101].try_into().unwrap();
124 let iv: [u8; 16] = header[102..HEADER_LEN].try_into().unwrap();
125
126 Poll::Ready(Ok(Box::new(DecryptAwaitingPassword {
127 read,
128 buffer,
129 salt,
130 verification_hash,
131 iv,
132 version_byte: header[4],
133 compression_algo: header[5],
134 bytes_per_poll: *this.bytes_per_poll
135 })))
136 } else {
137 let read = this.read.as_pin_mut().unwrap();
138
139 match read.poll_next(cx) {
140 Poll::Ready(Some(Ok(bytes))) => {
141 this.buffer.as_mut().unwrap().put(bytes);
142 cx.waker().wake_by_ref();
143 Poll::Pending
144 },
145 Poll::Ready(Some(Err(e))) => Poll::Ready(Err(SsecHeaderError::Stream(e))),
146 Poll::Ready(None) => Poll::Ready(Err(SsecHeaderError::NotSsec)),
147 Poll::Pending => Poll::Pending
148 }
149 }
150 }
151}
152
153pub struct DecryptAwaitingPassword<R> {
154 read: R,
155 buffer: BytesMut,
156 salt: [u8; 32],
157 verification_hash: [u8; 64],
158 iv: [u8; 16],
159 version_byte: u8,
160 compression_algo: u8,
161 bytes_per_poll: NonZeroUsize
162}
163
/// Number of bytes at the very end of the stream that hold the stored
/// integrity code (256 bits).
const HMAC_LEN: usize = 256 / 8;
165
166impl<R> DecryptAwaitingPassword<R> {
167 pub fn try_password(mut self: Box<Self>, password: &[u8]) -> Result<DecryptStream<R>, Box<Self>> {
174 let key = kdf(password, &self.salt);
175
176 if constant_time_eq_64(compute_verification_hash(&key).as_ref(), &self.verification_hash) {
177 let mut integrity_code = HmacSha3_256::new_from_slice(key.as_ref().get_ref()).unwrap();
178 integrity_code.update(&[self.version_byte, self.compression_algo]);
179 integrity_code.update(&self.iv);
180
181 let buf_len = self.buffer.len();
182 let eof_buf = if buf_len >= HMAC_LEN {
183 self.buffer.split_off(buf_len - HMAC_LEN)
184 } else {
185 self.buffer.split()
186 };
187 debug_assert!(eof_buf.len() <= HMAC_LEN);
188
189 let state = DecryptState::PostHeader(Box::new(DecryptionState {
190 aes: Aes256Ctr::new(key.as_ref().get_ref().into(), (&self.iv).into()),
191 integrity_code: Some(integrity_code),
192 eof: false,
193 eof_buf
194 }));
195
196 Ok(DecryptStream {
197 read: self.read,
198 state,
199 buffer: self.buffer,
200 bytes_per_poll: self.bytes_per_poll
201 })
202 } else {
203 Err(self)
204 }
205 }
206}
207
208struct DecryptionState {
209 aes: Aes256Ctr,
210 integrity_code: Option<HmacSha3_256>,
211 eof: bool,
212 eof_buf: BytesMut
213}
214
215enum DecryptState {
216 PostHeader(Box<DecryptionState>),
217 Done
218}
219
220pin_project_lite::pin_project! {
221 pub struct DecryptStream<R> {
222 #[pin]
223 read: R,
224 state: DecryptState,
225 buffer: BytesMut,
226 bytes_per_poll: NonZeroUsize
227 }
228}
229
/// Reasons a [`DecryptStream`] can fail.
#[derive(Debug)]
pub enum DecryptStreamError<E> {
    /// The wrapped stream ended before `HMAC_LEN` trailing bytes were
    /// available, so there is no integrity code to check.
    TooShort,
    /// The integrity code computed over the received data does not match the
    /// one stored at the end of the stream.
    IntegrityFailed,
    /// The wrapped stream itself reported an error.
    Stream(E),
}
236
237impl<E: Display> Display for DecryptStreamError<E> {
238 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
239 match self {
240 Self::TooShort => write!(f, "wrapped stream was too short to have been a valid SSEC file (no integrity code)"),
241 Self::IntegrityFailed => write!(f, "the file has been tampered with, previously decrypted data is inauthentic and should be discarded"),
242 Self::Stream(e) => e.fmt(f)
243 }
244 }
245}
246
247impl<E> Error for DecryptStreamError<E>
248where
249 E: Error + 'static,
250 Self: Display
251{
252 #[inline]
253 fn source(&self) -> Option<&(dyn Error + 'static)> {
254 match self {
255 Self::TooShort => None,
256 Self::IntegrityFailed => None,
257 Self::Stream(e) => Some(e)
258 }
259 }
260}
261
262impl<E, R> Stream for DecryptStream<R>
263where
264 R: Stream<Item = Result<Bytes, E>>
265{
266 type Item = Result<Bytes, DecryptStreamError<E>>;
267
268 fn poll_next(
269 self: Pin<&mut Self>,
270 cx: &mut Context<'_>
271 ) -> Poll<Option<Self::Item>> {
272 let this = self.project();
273
274 match this.state {
275 DecryptState::PostHeader(state) => {
276 if state.eof && this.buffer.len() <= this.bytes_per_poll.get() {
277 if state.eof_buf.len() < HMAC_LEN {
278 *this.state = DecryptState::Done;
279 return Poll::Ready(Some(Err(DecryptStreamError::TooShort)));
280 }
281 debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
282
283 let mut hmac = state.integrity_code.take().expect("integrity_code only taken here");
284 let mut data = this.buffer.split();
285
286 hmac.update(&data);
287 state.aes.apply_keystream(&mut data);
288
289 let stored_integrity_code: &[u8; HMAC_LEN] = state.eof_buf.as_ref().try_into().unwrap();
290 if !constant_time_eq_32(stored_integrity_code, hmac.finalize().into_bytes().as_ref()) {
291 *this.state = DecryptState::Done;
292 return Poll::Ready(Some(Err(DecryptStreamError::IntegrityFailed)));
293 }
294
295 *this.state = DecryptState::Done;
296
297 Poll::Ready(Some(Ok(data.freeze())))
298 } else if this.buffer.len() >= this.bytes_per_poll.get() {
299 let mut data = this.buffer.split_to(this.bytes_per_poll.get());
300
301 state.integrity_code.as_mut().unwrap().update(&data);
302 state.aes.apply_keystream(&mut data);
303
304 Poll::Ready(Some(Ok(data.freeze())))
305 } else {
306 match ready!(this.read.poll_next(cx)) {
307 Some(Ok(bytes)) => {
308 state.eof_buf.put(bytes);
309 let eof_len = state.eof_buf.len();
310 if eof_len > HMAC_LEN {
311 this.buffer.put(state.eof_buf.split_to(eof_len - HMAC_LEN));
312 debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
313 }
314 cx.waker().wake_by_ref();
315 Poll::Pending
316 },
317 Some(Err(e)) => {
318 *this.state = DecryptState::Done;
319 Poll::Ready(Some(Err(DecryptStreamError::Stream(e))))
320 },
321 None => {
322 debug_assert!(state.eof_buf.len() <= HMAC_LEN);
323 state.eof = true;
324 cx.waker().wake_by_ref();
325 Poll::Pending
326 }
327 }
328 }
329 },
330 DecryptState::Done => Poll::Ready(None)
331 }
332 }
333}