Skip to main content

securer_string/secure_types/
array.rs

1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::str::FromStr;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::secure_utils::memlock;
9
10/// A data type suitable for storing sensitive information such as passwords and
11/// private keys in memory, that implements:
12///
13/// - Automatic zeroing in `Drop`
14/// - Constant time comparison in `PartialEq` (does not short circuit on the
15///   first different character; but terminates instantly if strings have
16///   different length)
17/// - Outputting `***SECRET***` to prevent leaking secrets into logs in
18///   `fmt::Debug` and `fmt::Display`
19/// - Automatic `mlock` to protect against leaking into swap (any unix)
20/// - Automatic `madvise(MADV_NOCORE/MADV_DONTDUMP)` to protect against leaking
21///   into core dumps (FreeBSD, DragonflyBSD, Linux)
22pub struct SecureArray<T, const LENGTH: usize>
23where
24    T: Copy + Zeroize,
25{
26    pub(crate) content: [T; LENGTH],
27}
28
29impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
30where
31    T: Copy + Zeroize,
32{
33    pub fn new(mut content: [T; LENGTH]) -> Self {
34        memlock::mlock(content.as_mut_ptr(), content.len());
35        Self { content }
36    }
37
38    /// Borrow the contents of the string.
39    pub fn unsecure(&self) -> &[T] {
40        self.borrow()
41    }
42
43    /// Mutably borrow the contents of the string.
44    pub fn unsecure_mut(&mut self) -> &mut [T] {
45        self.borrow_mut()
46    }
47
48    /// Overwrite the string with zeros. This is automatically called in the
49    /// destructor.
50    pub fn zero_out(&mut self) {
51        self.content.zeroize();
52    }
53}
54
55impl<T: Copy + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
56    fn clone(&self) -> Self {
57        Self::new(self.content)
58    }
59}
60
61impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> ConstantTimeEq
62    for SecureArray<T, LENGTH>
63{
64    fn ct_eq(&self, other: &Self) -> subtle::Choice {
65        self.content.as_slice().ct_eq(other.content.as_slice())
66    }
67}
68
69impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> PartialEq for SecureArray<T, LENGTH> {
70    fn eq(&self, other: &Self) -> bool {
71        self.ct_eq(other).into()
72    }
73}
74
75impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> Eq for SecureArray<T, LENGTH> {}
76
77// Creation
78impl<T, const LENGTH: usize> From<[T; LENGTH]> for SecureArray<T, LENGTH>
79where
80    T: Copy + Zeroize,
81{
82    fn from(s: [T; LENGTH]) -> Self {
83        Self::new(s)
84    }
85}
86
87impl<T, const LENGTH: usize> TryFrom<Vec<T>> for SecureArray<T, LENGTH>
88where
89    T: Copy + Zeroize,
90{
91    type Error = String;
92
93    fn try_from(s: Vec<T>) -> Result<Self, Self::Error> {
94        Ok(Self::new(s.try_into().map_err(|error: Vec<T>| {
95            format!(
96                "length mismatch: expected {LENGTH}, but got {}",
97                error.len()
98            )
99        })?))
100    }
101}
102
103impl<const LENGTH: usize> FromStr for SecureArray<u8, LENGTH> {
104    type Err = std::array::TryFromSliceError;
105
106    fn from_str(s: &str) -> Result<Self, Self::Err> {
107        Ok(SecureArray::new(s.as_bytes().try_into()?))
108    }
109}
110
111// Array item indexing
112impl<T, U, const LENGTH: usize> std::ops::Index<U> for SecureArray<T, LENGTH>
113where
114    T: Copy + Zeroize,
115    [T; LENGTH]: std::ops::Index<U>,
116{
117    type Output = <[T; LENGTH] as std::ops::Index<U>>::Output;
118
119    fn index(&self, index: U) -> &Self::Output {
120        std::ops::Index::index(&self.content, index)
121    }
122}
123
124// Borrowing
125impl<T, const LENGTH: usize> Borrow<[T]> for SecureArray<T, LENGTH>
126where
127    T: Copy + Zeroize,
128{
129    fn borrow(&self) -> &[T] {
130        self.content.borrow()
131    }
132}
133
134impl<T, const LENGTH: usize> BorrowMut<[T]> for SecureArray<T, LENGTH>
135where
136    T: Copy + Zeroize,
137{
138    fn borrow_mut(&mut self) -> &mut [T] {
139        self.content.borrow_mut()
140    }
141}
142
143// Overwrite memory with zeros when we're done
144impl<T, const LENGTH: usize> Drop for SecureArray<T, LENGTH>
145where
146    T: Copy + Zeroize,
147{
148    fn drop(&mut self) {
149        self.zero_out();
150        memlock::munlock(self.content.as_mut_ptr(), self.content.len());
151    }
152}
153
154// Make sure sensitive information is not logged accidentally
155impl<T, const LENGTH: usize> fmt::Debug for SecureArray<T, LENGTH>
156where
157    T: Copy + Zeroize,
158{
159    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
160        f.debug_struct("SecureArray").finish_non_exhaustive()
161    }
162}
163
164impl<T, const LENGTH: usize> fmt::Display for SecureArray<T, LENGTH>
165where
166    T: Copy + Zeroize,
167{
168    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
169        f.write_str("***SECRET***").map_err(|_| fmt::Error)
170    }
171}
172
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::SecureArray;

    /// Parses a five-byte secret, panicking if `text` is not exactly 5 bytes.
    fn five(text: &str) -> SecureArray<u8, 5> {
        SecureArray::from_str(text).unwrap()
    }

    /// Encodes the eight `char`s of `text` as `u32` code points.
    fn code_points(text: &str) -> [u32; 8] {
        assert_eq!(text.chars().count(), 8, "helper expects exactly 8 chars");
        let mut out = [0u32; 8];
        for (slot, ch) in out.iter_mut().zip(text.chars()) {
            *slot = ch as u32;
        }
        out
    }

    #[test]
    fn test_basic() {
        let my_sec = five("hello");
        assert_eq!(my_sec, five("hello"));
        assert_eq!(my_sec.unsecure(), b"hello");
    }

    #[test]
    fn test_zero_out() {
        let mut my_sec = five("hello");
        my_sec.zero_out();
        assert_eq!(my_sec.unsecure(), &[0u8; 5]);
    }

    #[test]
    fn test_comparison() {
        assert_eq!(five("hello"), five("hello"));
        assert_ne!(five("hello"), five("olleh"));
    }

    #[test]
    fn test_indexing() {
        let string = five("hello");
        assert_eq!(string[0], b'h');
        assert_eq!(&string[3..5], "lo".as_bytes());
    }

    #[test]
    fn test_show() {
        assert_eq!(format!("{:?}", five("hello")), "SecureArray { .. }");
        assert_eq!(format!("{}", five("hello")), "***SECRET***");
    }

    #[test]
    fn test_comparison_zero_out_multibyte() {
        let data1 = SecureArray::from(code_points("Hallo 🦄!"));
        let data2 = SecureArray::from(code_points("Hallo 🦄!"));
        let data3 = SecureArray::from(code_points("!🦄 ollaH"));
        assert_eq!(data1, data2);
        assert_ne!(data1, data3);

        let mut zeroed = data1.clone();
        zeroed.zero_out();
        assert_eq!(zeroed.unsecure(), &[0u32; 8]);
    }
}