Skip to main content

securer_string/secure_types/
vec.rs

1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::str::FromStr;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::secure_utils::memlock;
9
10/// A data type suitable for storing sensitive information such as passwords and
11/// private keys in memory, that implements:
12///
13/// - Automatic zeroing in `Drop`
14/// - Constant time comparison in `PartialEq` (does not short circuit on the
15///   first different character; but terminates instantly if strings have
16///   different length)
17/// - Outputting `***SECRET***` to prevent leaking secrets into logs in
18///   `fmt::Debug` and `fmt::Display`
19/// - Automatic `mlock` to protect against leaking into swap (any unix)
20/// - Automatic `madvise(MADV_NOCORE/MADV_DONTDUMP)` to protect against leaking
21///   into core dumps (FreeBSD, DragonflyBSD, Linux)
22///
23/// `PartialEq` and `Eq` are only implemented when `T: ConstantTimeEq`. The
24/// safety of comparisons with respect to padding bytes depends on the
25/// `ConstantTimeEq` implementation of `T`.
26///
27/// Be careful with `SecureBytes::from`: if you have a borrowed string, it will
28/// be copied. Use `SecureBytes::new` if you have a `Vec<u8>`.
29pub struct SecureVec<T>
30where
31    T: Copy + Zeroize,
32{
33    pub(crate) content: Vec<T>,
34}
35
36/// Type alias for a vector that stores just bytes
37pub type SecureBytes = SecureVec<u8>;
38
39impl<T> SecureVec<T>
40where
41    T: Copy + Zeroize,
42{
43    #[must_use]
44    pub fn new(mut cont: Vec<T>) -> Self {
45        memlock::mlock(cont.as_mut_ptr(), cont.capacity());
46        SecureVec { content: cont }
47    }
48
49    /// Borrow the contents of the string.
50    #[must_use]
51    pub fn unsecure(&self) -> &[T] {
52        self.borrow()
53    }
54
55    /// Mutably borrow the contents of the string.
56    pub fn unsecure_mut(&mut self) -> &mut [T] {
57        self.borrow_mut()
58    }
59
60    /// Resizes the `SecureVec` in-place so that len is equal to `new_len`.
61    ///
62    /// If `new_len` is smaller the inner vector is truncated.
63    /// If `new_len` is larger the inner vector will grow, placing `value` in
64    /// all new cells.
65    ///
66    /// This ensures that the new memory region is secured if reallocation
67    /// occurs.
68    ///
69    /// Similar to [`Vec::resize`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.resize)
70    pub fn resize(&mut self, new_len: usize, value: T) {
71        // Trucnate if shorter or same length
72        if new_len <= self.content.len() {
73            self.content.truncate(new_len);
74            return;
75        }
76
77        // Allocate new vector, copy old data into it
78        let mut new_vec = vec![value; new_len];
79        memlock::mlock(new_vec.as_mut_ptr(), new_vec.capacity());
80        new_vec[0..self.content.len()].copy_from_slice(&self.content);
81
82        // Securely clear old vector, replace with new vector
83        self.zero_out();
84        memlock::munlock(self.content.as_mut_ptr(), self.content.capacity());
85        self.content = new_vec;
86    }
87
88    /// Overwrite the string with zeros. This is automatically called in the
89    /// destructor.
90    ///
91    /// This also sets the length to `0`.
92    pub fn zero_out(&mut self) {
93        self.content.zeroize();
94    }
95}
96
97impl<T: Copy + Zeroize> Clone for SecureVec<T> {
98    fn clone(&self) -> Self {
99        Self::new(self.content.clone())
100    }
101}
102
103impl<T: Copy + Zeroize + ConstantTimeEq> ConstantTimeEq for SecureVec<T> {
104    fn ct_eq(&self, other: &Self) -> subtle::Choice {
105        self.content.ct_eq(&other.content)
106    }
107}
108
109impl<T: Copy + Zeroize + ConstantTimeEq> PartialEq for SecureVec<T> {
110    fn eq(&self, other: &Self) -> bool {
111        self.ct_eq(other).into()
112    }
113}
114
115impl<T: Copy + Zeroize + ConstantTimeEq> Eq for SecureVec<T> {}
116
117// Creation
118impl<T, U> From<U> for SecureVec<T>
119where
120    U: Into<Vec<T>>,
121    T: Copy + Zeroize,
122{
123    fn from(s: U) -> SecureVec<T> {
124        SecureVec::new(s.into())
125    }
126}
127
128impl FromStr for SecureVec<u8> {
129    type Err = std::convert::Infallible;
130
131    fn from_str(s: &str) -> Result<Self, Self::Err> {
132        Ok(SecureVec::new(s.into()))
133    }
134}
135
136// Vec item indexing
137impl<T, U> std::ops::Index<U> for SecureVec<T>
138where
139    T: Copy + Zeroize,
140    Vec<T>: std::ops::Index<U>,
141{
142    type Output = <Vec<T> as std::ops::Index<U>>::Output;
143
144    fn index(&self, index: U) -> &Self::Output {
145        std::ops::Index::index(&self.content, index)
146    }
147}
148
149// Borrowing
150impl<T> Borrow<[T]> for SecureVec<T>
151where
152    T: Copy + Zeroize,
153{
154    fn borrow(&self) -> &[T] {
155        self.content.borrow()
156    }
157}
158
159impl<T> BorrowMut<[T]> for SecureVec<T>
160where
161    T: Copy + Zeroize,
162{
163    fn borrow_mut(&mut self) -> &mut [T] {
164        self.content.borrow_mut()
165    }
166}
167
168// Overwrite memory with zeros when we're done
169impl<T> Drop for SecureVec<T>
170where
171    T: Copy + Zeroize,
172{
173    fn drop(&mut self) {
174        self.zero_out();
175        memlock::munlock(self.content.as_mut_ptr(), self.content.capacity());
176    }
177}
178
179// Make sure sensitive information is not logged accidentally
180impl<T> fmt::Debug for SecureVec<T>
181where
182    T: Copy + Zeroize,
183{
184    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
185        f.write_str("***SECRET***").map_err(|_| fmt::Error)
186    }
187}
188
189impl<T> fmt::Display for SecureVec<T>
190where
191    T: Copy + Zeroize,
192{
193    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
194        f.write_str("***SECRET***").map_err(|_| fmt::Error)
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::{SecureBytes, SecureVec};

    /// Encodes every `char` of `text` as its `u32` code point.
    fn code_points(text: &str) -> Vec<u32> {
        text.chars().map(u32::from).collect()
    }

    #[test]
    fn test_basic() {
        let secret = SecureBytes::from("hello");
        assert_eq!(secret, SecureBytes::from(String::from("hello")));
        assert_eq!(secret.unsecure(), b"hello");
    }

    #[test]
    fn test_zero_out() {
        let mut secret = SecureBytes::from("hello");
        secret.zero_out();
        // `zero_out` resets `len` to 0; restore it to inspect the wiped bytes.
        // SAFETY: the buffer was allocated for the 5 bytes of "hello", and
        // zeroing leaves them initialized (as 0), so they are valid `u8`s.
        unsafe { secret.content.set_len(5) };
        assert_eq!(secret.unsecure(), &[0u8; 5]);
    }

    #[test]
    fn test_resize() {
        let mut secret = SecureVec::from([0, 1]);
        assert_eq!(secret.unsecure().len(), 2);

        secret.resize(1, 0);
        assert_eq!(secret.unsecure().len(), 1);

        secret.resize(16, 2);
        let mut expected = [2; 16];
        expected[0] = 0;
        assert_eq!(secret.unsecure(), &expected);
    }

    #[test]
    fn test_comparison() {
        let hello = SecureBytes::from("hello");
        assert_eq!(hello, SecureBytes::from("hello"));
        for other in ["yolo", "olleh", "helloworld", ""] {
            assert_ne!(hello, SecureBytes::from(other));
        }
    }

    #[test]
    fn test_indexing() {
        let secret = SecureBytes::from("hello");
        assert_eq!(secret[0], b'h');
        assert_eq!(&secret[3..5], b"lo");
    }

    #[test]
    fn test_show() {
        let secret = SecureBytes::from("hello");
        assert_eq!(format!("{:?}", secret), "***SECRET***");
        assert_eq!(format!("{}", secret), "***SECRET***");
    }

    #[test]
    fn test_comparison_zero_out_mb() {
        let forward = SecureVec::from(code_points("Hallo 🦄!"));
        let forward_again = SecureVec::from(code_points("Hallo 🦄!"));
        let reversed = SecureVec::from(code_points("!🦄 ollaH"));
        assert_eq!(forward, forward_again);
        assert_ne!(forward, reversed);

        let mut wiped = forward.clone();
        wiped.zero_out();
        // `zero_out` resets `len` to 0; restore it to inspect the wiped values.
        // SAFETY: the clone was allocated for 8 code points, and zeroing
        // leaves them initialized (as 0), so they are valid `u32`s.
        unsafe { wiped.content.set_len(8) };
        assert_eq!(wiped.unsecure(), &[0u32; 8]);
    }
}