Skip to main content

securer_string/secure_types/
vec.rs

1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::str::FromStr;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::secure_utils::memlock;
9
10/// A data type suitable for storing sensitive information such as passwords and
11/// private keys in memory, that implements:
12///
13/// - Automatic zeroing in `Drop`
14/// - Constant time comparison in `PartialEq` (does not short circuit on the
15///   first different character; but terminates instantly if strings have
16///   different length)
17/// - Outputting `***SECRET***` to prevent leaking secrets into logs in
18///   `fmt::Debug` and `fmt::Display`
19/// - Automatic `mlock` to protect against leaking into swap (any unix)
20/// - Automatic `madvise(MADV_NOCORE/MADV_DONTDUMP)` to protect against leaking
21///   into core dumps (FreeBSD, DragonflyBSD, Linux)
22///
23/// `PartialEq` and `Eq` are only implemented when `T: ConstantTimeEq`. The
24/// safety of comparisons with respect to padding bytes depends on the
25/// `ConstantTimeEq` implementation of `T`.
26///
27/// Be careful with `SecureBytes::from`: if you have a borrowed string, it will
28/// be copied. Use `SecureBytes::new` if you have a `Vec<u8>`.
29pub struct SecureVec<T>
30where
31    T: Copy + Zeroize,
32{
33    pub(crate) content: Vec<T>,
34}
35
36/// Type alias for a vector that stores just bytes
37pub type SecureBytes = SecureVec<u8>;
38
39impl<T> SecureVec<T>
40where
41    T: Copy + Zeroize,
42{
43    pub fn new(mut cont: Vec<T>) -> Self {
44        memlock::mlock(cont.as_mut_ptr(), cont.capacity());
45        SecureVec { content: cont }
46    }
47
48    /// Borrow the contents of the string.
49    pub fn unsecure(&self) -> &[T] {
50        self.borrow()
51    }
52
53    /// Mutably borrow the contents of the string.
54    pub fn unsecure_mut(&mut self) -> &mut [T] {
55        self.borrow_mut()
56    }
57
58    /// Resizes the `SecureVec` in-place so that len is equal to `new_len`.
59    ///
60    /// If `new_len` is smaller the inner vector is truncated.
61    /// If `new_len` is larger the inner vector will grow, placing `value` in
62    /// all new cells.
63    ///
64    /// This ensures that the new memory region is secured if reallocation
65    /// occurs.
66    ///
67    /// Similar to [`Vec::resize`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.resize)
68    pub fn resize(&mut self, new_len: usize, value: T) {
69        // Trucnate if shorter or same length
70        if new_len <= self.content.len() {
71            self.content.truncate(new_len);
72            return;
73        }
74
75        // Allocate new vector, copy old data into it
76        let mut new_vec = vec![value; new_len];
77        memlock::mlock(new_vec.as_mut_ptr(), new_vec.capacity());
78        new_vec[0..self.content.len()].copy_from_slice(&self.content);
79
80        // Securely clear old vector, replace with new vector
81        self.zero_out();
82        memlock::munlock(self.content.as_mut_ptr(), self.content.capacity());
83        self.content = new_vec;
84    }
85
86    /// Overwrite the string with zeros. This is automatically called in the
87    /// destructor.
88    ///
89    /// This also sets the length to `0`.
90    pub fn zero_out(&mut self) {
91        self.content.zeroize()
92    }
93}
94
95impl<T: Copy + Zeroize> Clone for SecureVec<T> {
96    fn clone(&self) -> Self {
97        Self::new(self.content.clone())
98    }
99}
100
101impl<T: Copy + Zeroize + ConstantTimeEq> ConstantTimeEq for SecureVec<T> {
102    fn ct_eq(&self, other: &Self) -> subtle::Choice {
103        self.content.ct_eq(&other.content)
104    }
105}
106
107impl<T: Copy + Zeroize + ConstantTimeEq> PartialEq for SecureVec<T> {
108    fn eq(&self, other: &Self) -> bool {
109        self.ct_eq(other).into()
110    }
111}
112
113impl<T: Copy + Zeroize + ConstantTimeEq> Eq for SecureVec<T> {}
114
115// Creation
116impl<T, U> From<U> for SecureVec<T>
117where
118    U: Into<Vec<T>>,
119    T: Copy + Zeroize,
120{
121    fn from(s: U) -> SecureVec<T> {
122        SecureVec::new(s.into())
123    }
124}
125
126impl FromStr for SecureVec<u8> {
127    type Err = std::convert::Infallible;
128
129    fn from_str(s: &str) -> Result<Self, Self::Err> {
130        Ok(SecureVec::new(s.into()))
131    }
132}
133
134// Vec item indexing
135impl<T, U> std::ops::Index<U> for SecureVec<T>
136where
137    T: Copy + Zeroize,
138    Vec<T>: std::ops::Index<U>,
139{
140    type Output = <Vec<T> as std::ops::Index<U>>::Output;
141
142    fn index(&self, index: U) -> &Self::Output {
143        std::ops::Index::index(&self.content, index)
144    }
145}
146
147// Borrowing
148impl<T> Borrow<[T]> for SecureVec<T>
149where
150    T: Copy + Zeroize,
151{
152    fn borrow(&self) -> &[T] {
153        self.content.borrow()
154    }
155}
156
157impl<T> BorrowMut<[T]> for SecureVec<T>
158where
159    T: Copy + Zeroize,
160{
161    fn borrow_mut(&mut self) -> &mut [T] {
162        self.content.borrow_mut()
163    }
164}
165
166// Overwrite memory with zeros when we're done
167impl<T> Drop for SecureVec<T>
168where
169    T: Copy + Zeroize,
170{
171    fn drop(&mut self) {
172        self.zero_out();
173        memlock::munlock(self.content.as_mut_ptr(), self.content.capacity());
174    }
175}
176
177// Make sure sensitive information is not logged accidentally
178impl<T> fmt::Debug for SecureVec<T>
179where
180    T: Copy + Zeroize,
181{
182    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183        f.write_str("***SECRET***").map_err(|_| fmt::Error)
184    }
185}
186
187impl<T> fmt::Display for SecureVec<T>
188where
189    T: Copy + Zeroize,
190{
191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192        f.write_str("***SECRET***").map_err(|_| fmt::Error)
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::{SecureBytes, SecureVec};

    /// Encodes `text` as one `u32` per Unicode scalar value.
    fn code_points(text: &str) -> Vec<u32> {
        text.chars().map(|c| c as u32).collect()
    }

    #[test]
    fn test_basic() {
        let from_borrowed = SecureBytes::from("hello");
        let from_owned = SecureBytes::from("hello".to_string());
        assert_eq!(from_borrowed, from_owned);
        assert_eq!(from_borrowed.unsecure(), b"hello");
    }

    #[test]
    fn test_zero_out() {
        let mut secret = SecureBytes::from("hello");
        secret.zero_out();
        // `zero_out` sets the `len` to 0, here we reset it to check that the bytes were
        // zeroed
        // SAFETY: the allocation still holds the 5 bytes of "hello" (now
        // overwritten), so 5 is within capacity and every byte is initialized.
        unsafe {
            secret.content.set_len(5);
        }
        assert_eq!(secret.unsecure(), b"\x00\x00\x00\x00\x00");
    }

    #[test]
    fn test_resize() {
        let mut secret = SecureVec::from([0, 1]);
        assert_eq!(secret.unsecure().len(), 2);

        // Shrinking keeps the leading element.
        secret.resize(1, 0);
        assert_eq!(secret.unsecure().len(), 1);

        // Growing keeps the old element and fills the rest with `2`.
        secret.resize(16, 2);
        let mut expected = vec![2; 16];
        expected[0] = 0;
        assert_eq!(secret.unsecure(), expected.as_slice());
    }

    #[test]
    fn test_comparison() {
        let reference = SecureBytes::from("hello");
        assert_eq!(reference, SecureBytes::from("hello"));
        for different in ["yolo", "olleh", "helloworld", ""] {
            assert_ne!(reference, SecureBytes::from(different));
        }
    }

    #[test]
    fn test_indexing() {
        let secret = SecureBytes::from("hello");
        assert_eq!(secret[0], b'h');
        assert_eq!(&secret[3..5], b"lo".as_slice());
    }

    #[test]
    fn test_show() {
        let secret = SecureBytes::from("hello");
        assert_eq!(format!("{:?}", secret), "***SECRET***");
        assert_eq!(format!("{}", secret), "***SECRET***");
    }

    #[test]
    fn test_comparison_zero_out_mb() {
        let mbstring1 = SecureVec::from(code_points("Hallo 🦄!"));
        let mbstring2 = SecureVec::from(code_points("Hallo 🦄!"));
        let mbstring3 = SecureVec::from(code_points("!🦄 ollaH"));
        assert!(mbstring1 == mbstring2);
        assert!(mbstring1 != mbstring3);

        let mut wiped = mbstring1.clone();
        wiped.zero_out();
        // `zero_out` sets the `len` to 0, here we reset it to check that the bytes were
        // zeroed
        // SAFETY: the clone held 8 `u32` values before `zero_out`, so 8 is
        // within capacity and every element is initialized (zeroed).
        unsafe {
            wiped.content.set_len(8);
        }
        assert_eq!(wiped.unsecure(), &[0u32; 8]);
    }
}
304}