Skip to main content

ssec_core/
decrypt.rs

1use futures_core::Stream;
2use bytes::{Bytes, BytesMut, BufMut};
3use ctr::cipher::{KeyIvInit, StreamCipher};
4use hmac::{Mac, KeyInit};
5use constant_time_eq::{constant_time_eq_32, constant_time_eq_64};
6use core::pin::Pin;
7use core::fmt::Display;
8use core::error::Error;
9use core::task::{Context, Poll, ready};
10use core::num::NonZeroUsize;
11use crate::util::{HmacSha3_256, kdf, compute_verification_hash};
12use crate::{DEFAULT_BYTES_PER_POLL, Aes256Ctr};
13
/// Optional settings for [`Decrypt::new`].
///
/// Start from [`DecryptArgs::default`] and adjust the values with the setter methods.
#[derive(Clone, Copy, Debug)]
pub struct DecryptArgs {
	// Forwarded to the `DecryptStream` built after a successful password check; that stream
	// never yields a decrypted chunk larger than this many bytes.
	bytes_per_poll: NonZeroUsize
}
19
20impl Default for DecryptArgs {
21	/// default settings are not part of semver contract
22	fn default() -> Self {
23		Self {
24			bytes_per_poll: DEFAULT_BYTES_PER_POLL
25		}
26	}
27}
28
29impl DecryptArgs {
30	/// sets the maximum number of bytes to decrypt before yielding to the executor
31	pub fn set_bytes_per_poll(&mut self, bytes_per_poll: NonZeroUsize) {
32		self.bytes_per_poll = bytes_per_poll;
33	}
34}
35
36pin_project_lite::pin_project! {
37	pub struct Decrypt<R> {
38		#[pin]
39		read: Option<R>,
40		buffer: Option<BytesMut>,
41		bytes_per_poll: NonZeroUsize
42	}
43}
44
45impl<R> Decrypt<R> {
46	pub fn new(additional_args: DecryptArgs, read: R) -> Self {
47		Self {
48			read: Some(read),
49			buffer: Some(BytesMut::new()),
50			bytes_per_poll: additional_args.bytes_per_poll
51		}
52	}
53}
54
/// Reasons the header of an SSEC stream could not be read by [`Decrypt`].
#[derive(Debug)]
pub enum SsecHeaderError<E> {
	/// The wrapped stream ended before a full header arrived, the magic bytes were not
	/// `SSEC`, or the compression byte is not a recognised value.
	NotSsec,
	/// The header's version byte is not `0x01`; holds the byte that was found.
	UnsupportedVersion(u8),
	/// The compression byte names a recognised algorithm (`0x62`) that this decoder does
	/// not implement; holds that byte.
	UnsupportedCompression(u8),
	/// The wrapped stream returned an error while the header was being read.
	Stream(E)
}
62
63impl<E: Display> Display for SsecHeaderError<E> {
64	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65		match self {
66			Self::NotSsec => write!(f, "wrapped stream did not produce a SSEC file"),
67			Self::UnsupportedVersion(v) => write!(f, "SSEC file version {v:?} is unsupported"),
68			Self::UnsupportedCompression(c) => write!(f, "SSEC compression algorithm {c:?} is valid but currently unsupported"),
69			Self::Stream(e) => e.fmt(f)
70		}
71	}
72}
73
74impl<E> Error for SsecHeaderError<E>
75where
76	E: Error + 'static,
77	Self: Display
78{
79	#[inline]
80	fn source(&self) -> Option<&(dyn Error + 'static)> {
81		match self {
82			Self::NotSsec => None,
83			Self::UnsupportedVersion(_) => None,
84			Self::UnsupportedCompression(_) => None,
85			Self::Stream(e) => Some(e)
86		}
87	}
88}
89
/// Size in bytes of the fixed SSEC header: magic `SSEC` (4) + version (1) +
/// compression algorithm (1) + salt (32) + verification hash (64) + IV (16).
const HEADER_LEN: usize = 4 + 1 + 1 + 32 + 64 + 16;
91
92impl<E, R: Stream<Item = Result<Bytes, E>> + Unpin> Future for Decrypt<R> {
93	type Output = Result<Box<DecryptAwaitingPassword<R>>, SsecHeaderError<E>>;
94
95	fn poll(
96		self: Pin<&mut Self>,
97		cx: &mut Context<'_>
98	) -> Poll<Self::Output> {
99		let this = self.project();
100
101		if this.buffer.as_ref().unwrap().len() >= HEADER_LEN {
102			let read = this.read.get_mut().take().unwrap();
103			let mut buffer = this.buffer.take().unwrap();
104
105			let header = buffer.split_to(HEADER_LEN);
106
107			if &header[0..=3] != b"SSEC" {
108				return Poll::Ready(Err(SsecHeaderError::NotSsec));
109			}
110
111			if header[4] != 0x01 {
112				return Poll::Ready(Err(SsecHeaderError::UnsupportedVersion(header[4])));
113			}
114
115			if header[5] != 0x6e {
116				return Poll::Ready(Err(match header[5] {
117					0x62 => SsecHeaderError::UnsupportedCompression(0x62),
118					_ => SsecHeaderError::NotSsec
119				}));
120			}
121
122			let salt: [u8; 32] = header[6..=37].try_into().unwrap();
123			let verification_hash: [u8; 64] = header[38..=101].try_into().unwrap();
124			let iv: [u8; 16] = header[102..HEADER_LEN].try_into().unwrap();
125
126			Poll::Ready(Ok(Box::new(DecryptAwaitingPassword {
127				read,
128				buffer,
129				salt,
130				verification_hash,
131				iv,
132				version_byte: header[4],
133				compression_algo: header[5],
134				bytes_per_poll: *this.bytes_per_poll
135			})))
136		} else {
137			let read = this.read.as_pin_mut().unwrap();
138
139			match read.poll_next(cx) {
140				Poll::Ready(Some(Ok(bytes))) => {
141					this.buffer.as_mut().unwrap().put(bytes);
142					cx.waker().wake_by_ref();
143					Poll::Pending
144				},
145				Poll::Ready(Some(Err(e))) => Poll::Ready(Err(SsecHeaderError::Stream(e))),
146				Poll::Ready(None) => Poll::Ready(Err(SsecHeaderError::NotSsec)),
147				Poll::Pending => Poll::Pending
148			}
149		}
150	}
151}
152
153pub struct DecryptAwaitingPassword<R> {
154	read: R,
155	buffer: BytesMut,
156	salt: [u8; 32],
157	verification_hash: [u8; 64],
158	iv: [u8; 16],
159	version_byte: u8,
160	compression_algo: u8,
161	bytes_per_poll: NonZeroUsize
162}
163
/// Length in bytes of the integrity code that terminates an SSEC stream; it is compared
/// against the `HmacSha3_256` output with `constant_time_eq_32`, i.e. 256 bits.
const HMAC_LEN: usize = 256 / 8;
165
166impl<R> DecryptAwaitingPassword<R> {
167	/// This method is *very* blocking.
168	/// If you're using Tokio I advise that you wrap this call in a `spawn_blocking`.
169	///
170	/// If a `Result::Err` is returned it indicates the password was incorrect.
171	///
172	/// SECURITY: It is advisable to zero out the memory containing the password after this method returns.
173	pub fn try_password(mut self: Box<Self>, password: &[u8]) -> Result<DecryptStream<R>, Box<Self>> {
174		let key = kdf(password, &self.salt);
175
176		if constant_time_eq_64(compute_verification_hash(&key).as_ref(), &self.verification_hash) {
177			let mut integrity_code = HmacSha3_256::new_from_slice(key.as_ref().get_ref()).unwrap();
178			integrity_code.update(&[self.version_byte, self.compression_algo]);
179			integrity_code.update(&self.iv);
180
181			let buf_len = self.buffer.len();
182			let eof_buf = if buf_len >= HMAC_LEN {
183				self.buffer.split_off(buf_len - HMAC_LEN)
184			} else {
185				self.buffer.split()
186			};
187			debug_assert!(eof_buf.len() <= HMAC_LEN);
188
189			let state = DecryptState::PostHeader(Box::new(DecryptionState {
190				aes: Aes256Ctr::new(key.as_ref().get_ref().into(), (&self.iv).into()),
191				integrity_code: Some(integrity_code),
192				eof: false,
193				eof_buf
194			}));
195
196			Ok(DecryptStream {
197				read: self.read,
198				state,
199				buffer: self.buffer,
200				bytes_per_poll: self.bytes_per_poll
201			})
202		} else {
203			Err(self)
204		}
205	}
206}
207
208struct DecryptionState {
209	aes: Aes256Ctr,
210	integrity_code: Option<HmacSha3_256>,
211	eof: bool,
212	eof_buf: BytesMut
213}
214
215enum DecryptState {
216	PostHeader(Box<DecryptionState>),
217	Done
218}
219
220pin_project_lite::pin_project! {
221	pub struct DecryptStream<R> {
222		#[pin]
223		read: R,
224		state: DecryptState,
225		buffer: BytesMut,
226		bytes_per_poll: NonZeroUsize
227	}
228}
229
/// Errors produced by [`DecryptStream`].
#[derive(Debug)]
pub enum DecryptStreamError<E> {
	/// Fewer than `HMAC_LEN` bytes followed the header, so there is no integrity code.
	TooShort,
	/// The integrity code at the end of the stream did not match the data. This is only
	/// detected when the last chunk is produced; every chunk yielded earlier is untrusted.
	IntegrityFailed,
	/// The wrapped stream returned an error.
	Stream(E)
}
236
237impl<E: Display> Display for DecryptStreamError<E> {
238	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
239		match self {
240			Self::TooShort => write!(f, "wrapped stream was too short to have been a valid SSEC file (no integrity code)"),
241			Self::IntegrityFailed => write!(f, "the file has been tampered with, previously decrypted data is inauthentic and should be discarded"),
242			Self::Stream(e) => e.fmt(f)
243		}
244	}
245}
246
247impl<E> Error for DecryptStreamError<E>
248where
249	E: Error + 'static,
250	Self: Display
251{
252	#[inline]
253	fn source(&self) -> Option<&(dyn Error + 'static)> {
254		match self {
255			Self::TooShort => None,
256			Self::IntegrityFailed => None,
257			Self::Stream(e) => Some(e)
258		}
259	}
260}
261
262impl<E, R> Stream for DecryptStream<R>
263where
264	R: Stream<Item = Result<Bytes, E>>
265{
266	type Item = Result<Bytes, DecryptStreamError<E>>;
267
268	fn poll_next(
269		self: Pin<&mut Self>,
270		cx: &mut Context<'_>
271	) -> Poll<Option<Self::Item>> {
272		let this = self.project();
273
274		match this.state {
275			DecryptState::PostHeader(state) => {
276				if state.eof && this.buffer.len() <= this.bytes_per_poll.get() {
277					if state.eof_buf.len() < HMAC_LEN {
278						*this.state = DecryptState::Done;
279						return Poll::Ready(Some(Err(DecryptStreamError::TooShort)));
280					}
281					debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
282
283					let mut hmac = state.integrity_code.take().expect("integrity_code only taken here");
284					let mut data = this.buffer.split();
285
286					hmac.update(&data);
287					state.aes.apply_keystream(&mut data);
288
289					let stored_integrity_code: &[u8; HMAC_LEN] = state.eof_buf.as_ref().try_into().unwrap();
290					if !constant_time_eq_32(stored_integrity_code, hmac.finalize().into_bytes().as_ref()) {
291						*this.state = DecryptState::Done;
292						return Poll::Ready(Some(Err(DecryptStreamError::IntegrityFailed)));
293					}
294
295					*this.state = DecryptState::Done;
296
297					Poll::Ready(Some(Ok(data.freeze())))
298				} else if this.buffer.len() >= this.bytes_per_poll.get() {
299					let mut data = this.buffer.split_to(this.bytes_per_poll.get());
300
301					state.integrity_code.as_mut().unwrap().update(&data);
302					state.aes.apply_keystream(&mut data);
303
304					Poll::Ready(Some(Ok(data.freeze())))
305				} else {
306					match ready!(this.read.poll_next(cx)) {
307						Some(Ok(bytes)) => {
308							state.eof_buf.put(bytes);
309							let eof_len = state.eof_buf.len();
310							if eof_len > HMAC_LEN {
311								this.buffer.put(state.eof_buf.split_to(eof_len - HMAC_LEN));
312								debug_assert_eq!(state.eof_buf.len(), HMAC_LEN);
313							}
314							cx.waker().wake_by_ref();
315							Poll::Pending
316						},
317						Some(Err(e)) => {
318							*this.state = DecryptState::Done;
319							Poll::Ready(Some(Err(DecryptStreamError::Stream(e))))
320						},
321						None => {
322							debug_assert!(state.eof_buf.len() <= HMAC_LEN);
323							state.eof = true;
324							cx.waker().wake_by_ref();
325							Poll::Pending
326						}
327					}
328				}
329			},
330			DecryptState::Done => Poll::Ready(None)
331		}
332	}
333}