ssec_core/
decrypt.rs

1use futures_core::Stream;
2use bytes::{Bytes, BytesMut, BufMut};
3use ctr::cipher::{KeyIvInit, StreamCipher};
4use hmac::{Mac, KeyInit};
5use constant_time_eq::{constant_time_eq_32, constant_time_eq_64};
6use core::pin::Pin;
7use core::fmt::Display;
8use core::error::Error;
9use core::task::{Context, Poll, ready};
10use crate::util::{HmacSha3_256, kdf, compute_verification_hash};
11use crate::{BYTES_PER_POLL, Aes256Ctr};
12
13pin_project_lite::pin_project! {
14	pub struct Decrypt<R> {
15		#[pin]
16		read: Option<R>,
17		buffer: Option<BytesMut>
18	}
19}
20
21impl<R> Decrypt<R> {
22	pub fn new(read: R) -> Self {
23		Self {
24			read: Some(read),
25			buffer: Some(BytesMut::new())
26		}
27	}
28}
29
/// Reasons the SSEC header could not be read from the wrapped stream.
#[derive(Debug)]
pub enum SsecHeaderError<E> {
	/// The magic bytes or compression byte are not those of an SSEC file, or the
	/// wrapped stream ended before a complete header was read.
	NotSsec,
	/// The header's version byte (carried here) is not a supported version.
	UnsupportedVersion(u8),
	/// The header's compression byte (carried here) names a recognised algorithm
	/// that this decoder does not implement.
	UnsupportedCompression(u8),
	/// The wrapped stream yielded an error.
	Stream(E)
}

impl<E: Display> Display for SsecHeaderError<E> {
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		match *self {
			// Wrapped-stream errors are shown exactly as the inner error shows itself.
			Self::Stream(ref inner) => Display::fmt(inner, f),
			Self::UnsupportedVersion(v) => write!(f, "SSEC file version {v:?} is unsupported"),
			Self::UnsupportedCompression(c) => write!(f, "SSEC compression algorithm {c:?} is valid but currently unsupported"),
			Self::NotSsec => f.write_str("wrapped stream did not produce a SSEC file")
		}
	}
}

impl<E> Error for SsecHeaderError<E>
where
	E: Error + 'static,
	Self: Display
{
	/// Only the `Stream` variant wraps another error.
	#[inline]
	fn source(&self) -> Option<&(dyn Error + 'static)> {
		if let Self::Stream(inner) = self {
			Some(inner)
		} else {
			None
		}
	}
}
64
/// Size in bytes of the fixed SSEC header: magic `SSEC` (4) + version (1) +
/// compression (1) + salt (32) + verification hash (64) + IV (16).
const HEADER_LEN: usize = 4 + 1 + 1 + 32 + 64 + 16;
66
67impl<E, R: Stream<Item = Result<Bytes, E>> + Unpin> Future for Decrypt<R> {
68	type Output = Result<Box<DecryptAwaitingPassword<R>>, SsecHeaderError<E>>;
69
70	fn poll(
71		self: Pin<&mut Self>,
72		cx: &mut Context<'_>
73	) -> Poll<Self::Output> {
74		let this = self.project();
75
76		if this.buffer.as_ref().unwrap().len() >= HEADER_LEN {
77			let read = this.read.get_mut().take().unwrap();
78			let mut buffer = this.buffer.take().unwrap();
79
80			let header = buffer.split_to(HEADER_LEN);
81
82			if &header[0..=3] != b"SSEC" {
83				return Poll::Ready(Err(SsecHeaderError::NotSsec));
84			}
85
86			if header[4] != 0x01 {
87				return Poll::Ready(Err(SsecHeaderError::UnsupportedVersion(header[4])));
88			}
89
90			if header[5] != 0x6e {
91				return Poll::Ready(Err(match header[5] {
92					0x62 => SsecHeaderError::UnsupportedCompression(0x62),
93					_ => SsecHeaderError::NotSsec
94				}));
95			}
96
97			let salt: [u8; 32] = header[6..=37].try_into().unwrap();
98			let verification_hash: [u8; 64] = header[38..=101].try_into().unwrap();
99			let iv: [u8; 16] = header[102..HEADER_LEN].try_into().unwrap();
100
101			Poll::Ready(Ok(Box::new(DecryptAwaitingPassword {
102				read,
103				buffer,
104				salt,
105				verification_hash,
106				iv,
107				version_byte: header[4],
108				compression_algo: header[5]
109			})))
110		} else {
111			let read = this.read.as_pin_mut().unwrap();
112
113			match read.poll_next(cx) {
114				Poll::Ready(Some(Ok(bytes))) => {
115					this.buffer.as_mut().unwrap().put(bytes);
116					cx.waker().wake_by_ref();
117					Poll::Pending
118				},
119				Poll::Ready(Some(Err(e))) => Poll::Ready(Err(SsecHeaderError::Stream(e))),
120				Poll::Ready(None) => Poll::Ready(Err(SsecHeaderError::NotSsec)),
121				Poll::Pending => Poll::Pending
122			}
123		}
124	}
125}
126
127pub struct DecryptAwaitingPassword<R> {
128	read: R,
129	buffer: BytesMut,
130	salt: [u8; 32],
131	verification_hash: [u8; 64],
132	iv: [u8; 16],
133	version_byte: u8,
134	compression_algo: u8
135}
136
/// Length in bytes of the integrity code that trails the ciphertext; equal to the
/// 32-byte tag produced by `HmacSha3_256` (a 256-bit MAC).
const HMAC_LEN: usize = 256 / 8;
138
139impl<R> DecryptAwaitingPassword<R> {
140	/// This method is *very* blocking.
141	/// If you're using Tokio I advise that you wrap this call in a `spawn_blocking`.
142	///
143	/// If a `Result::Err` is returned it indicates the password was incorrect.
144	///
145	/// SECURITY: It is advisable to zero out the memory containing the password after this method returns.
146	pub fn try_password(mut self: Box<Self>, password: &[u8]) -> Result<DecryptStream<R>, Box<Self>> {
147		let key = kdf(password, &self.salt);
148
149		if constant_time_eq_64(compute_verification_hash(&key).as_ref(), &self.verification_hash) {
150			let mut integrity_code = HmacSha3_256::new_from_slice(key.as_ref().get_ref()).unwrap();
151			integrity_code.update(&[self.version_byte, self.compression_algo]);
152			integrity_code.update(&self.iv);
153
154			let buf_len = self.buffer.len();
155			let eof_buf = if buf_len >= HMAC_LEN {
156				self.buffer.split_off(buf_len - HMAC_LEN)
157			} else {
158				self.buffer.split()
159			};
160			debug_assert!(eof_buf.len() <= HMAC_LEN);
161
162			let state = DecryptState::PostHeader(Box::new(DecryptionState {
163				aes: Aes256Ctr::new(key.as_ref().get_ref().into(), (&self.iv).into()),
164				integrity_code: Some(integrity_code),
165				eof: false,
166				eof_buf
167			}));
168
169			Ok(DecryptStream {
170				read: self.read,
171				state,
172				buffer: self.buffer
173			})
174		} else {
175			Err(self)
176		}
177	}
178}
179
180struct DecryptionState {
181	aes: Aes256Ctr,
182	integrity_code: Option<HmacSha3_256>,
183	eof: bool,
184	eof_buf: BytesMut
185}
186
187enum DecryptState {
188	PostHeader(Box<DecryptionState>),
189	Done
190}
191
192pin_project_lite::pin_project! {
193	pub struct DecryptStream<R> {
194		#[pin]
195		read: R,
196		state: DecryptState,
197		buffer: BytesMut,
198	}
199}
200
/// Errors yielded by [`DecryptStream`] while decrypting the body of an SSEC file.
#[derive(Debug)]
pub enum DecryptStreamError<E> {
	/// Fewer than `HMAC_LEN` bytes followed the header, so there is no integrity code.
	TooShort,
	/// The integrity code did not match; all plaintext already yielded is inauthentic.
	IntegrityFailed,
	/// The wrapped stream yielded an error.
	Stream(E)
}

impl<E: Display> Display for DecryptStreamError<E> {
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		let message = match self {
			// Wrapped-stream errors are shown exactly as the inner error shows itself.
			Self::Stream(inner) => return Display::fmt(inner, f),
			Self::TooShort => "wrapped stream was too short to have been a valid SSEC file (no integrity code)",
			Self::IntegrityFailed => "the file has been tampered with, previously decrypted data is inauthentic and should be discarded"
		};
		f.write_str(message)
	}
}

impl<E> Error for DecryptStreamError<E>
where
	E: Error + 'static,
	Self: Display
{
	/// Only the `Stream` variant wraps another error.
	#[inline]
	fn source(&self) -> Option<&(dyn Error + 'static)> {
		match self {
			Self::Stream(inner) => Some(inner),
			Self::TooShort | Self::IntegrityFailed => None
		}
	}
}
232
233impl<E, R> Stream for DecryptStream<R>
234where
235	R: Stream<Item = Result<Bytes, E>>
236{
237	type Item = Result<Bytes, DecryptStreamError<E>>;
238
239	fn poll_next(
240		self: Pin<&mut Self>,
241		cx: &mut Context<'_>
242	) -> Poll<Option<Self::Item>> {
243		let this = self.project();
244
245		match this.state {
246			DecryptState::PostHeader(state) => {
247				if state.eof && this.buffer.len() <= BYTES_PER_POLL {
248					if state.eof_buf.len() < HMAC_LEN {
249						*this.state = DecryptState::Done;
250						return Poll::Ready(Some(Err(DecryptStreamError::TooShort)));
251					}
252					debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
253
254					let mut hmac = state.integrity_code.take().expect("integrity_code only taken here");
255					let mut data = this.buffer.split();
256
257					hmac.update(&data);
258					state.aes.apply_keystream(&mut data);
259
260					let stored_integrity_code: &[u8; HMAC_LEN] = state.eof_buf.as_ref().try_into().unwrap();
261					if !constant_time_eq_32(stored_integrity_code, hmac.finalize().into_bytes().as_ref()) {
262						*this.state = DecryptState::Done;
263						return Poll::Ready(Some(Err(DecryptStreamError::IntegrityFailed)));
264					}
265
266					*this.state = DecryptState::Done;
267
268					Poll::Ready(Some(Ok(data.freeze())))
269				} else if this.buffer.len() >= BYTES_PER_POLL {
270					let mut data = this.buffer.split_to(BYTES_PER_POLL);
271
272					state.integrity_code.as_mut().unwrap().update(&data);
273					state.aes.apply_keystream(&mut data);
274
275					Poll::Ready(Some(Ok(data.freeze())))
276				} else {
277					match ready!(this.read.poll_next(cx)) {
278						Some(Ok(bytes)) => {
279							state.eof_buf.put(bytes);
280							let eof_len = state.eof_buf.len();
281							if eof_len > HMAC_LEN {
282								this.buffer.put(state.eof_buf.split_to(eof_len - HMAC_LEN));
283								debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
284							}
285							cx.waker().wake_by_ref();
286							Poll::Pending
287						},
288						Some(Err(e)) => {
289							*this.state = DecryptState::Done;
290							Poll::Ready(Some(Err(DecryptStreamError::Stream(e))))
291						},
292						None => {
293							debug_assert!(state.eof_buf.len() <= HMAC_LEN);
294							state.eof = true;
295							cx.waker().wake_by_ref();
296							Poll::Pending
297						}
298					}
299				}
300			},
301			DecryptState::Done => Poll::Ready(None)
302		}
303	}
304}