// securer_string/secure_types/array.rs
use core::fmt;
use std::borrow::{Borrow, BorrowMut};
use std::str::FromStr;

use subtle::ConstantTimeEq;
use zeroize::Zeroize;

use crate::secure_utils::memlock;

10pub struct SecureArray<T, const LENGTH: usize>
23where
24 T: Copy + Zeroize,
25{
26 pub(crate) content: [T; LENGTH],
27}
28
29impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
30where
31 T: Copy + Zeroize,
32{
33 pub fn new(mut content: [T; LENGTH]) -> Self {
34 memlock::mlock(content.as_mut_ptr(), content.len());
35 Self { content }
36 }
37
38 pub fn unsecure(&self) -> &[T] {
40 self.borrow()
41 }
42
43 pub fn unsecure_mut(&mut self) -> &mut [T] {
45 self.borrow_mut()
46 }
47
48 pub fn zero_out(&mut self) {
51 self.content.zeroize();
52 }
53}
54
55impl<T: Copy + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
56 fn clone(&self) -> Self {
57 Self::new(self.content)
58 }
59}
60
61impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> ConstantTimeEq
62 for SecureArray<T, LENGTH>
63{
64 fn ct_eq(&self, other: &Self) -> subtle::Choice {
65 self.content.as_slice().ct_eq(other.content.as_slice())
66 }
67}
68
69impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> PartialEq for SecureArray<T, LENGTH> {
70 fn eq(&self, other: &Self) -> bool {
71 self.ct_eq(other).into()
72 }
73}
74
75impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> Eq for SecureArray<T, LENGTH> {}
76
77impl<T, const LENGTH: usize> From<[T; LENGTH]> for SecureArray<T, LENGTH>
79where
80 T: Copy + Zeroize,
81{
82 fn from(s: [T; LENGTH]) -> Self {
83 Self::new(s)
84 }
85}
86
87impl<T, const LENGTH: usize> TryFrom<Vec<T>> for SecureArray<T, LENGTH>
88where
89 T: Copy + Zeroize,
90{
91 type Error = String;
92
93 fn try_from(s: Vec<T>) -> Result<Self, Self::Error> {
94 Ok(Self::new(s.try_into().map_err(|error: Vec<T>| {
95 format!(
96 "length mismatch: expected {LENGTH}, but got {}",
97 error.len()
98 )
99 })?))
100 }
101}
102
103impl<const LENGTH: usize> FromStr for SecureArray<u8, LENGTH> {
104 type Err = std::array::TryFromSliceError;
105
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 Ok(SecureArray::new(s.as_bytes().try_into()?))
108 }
109}
110
111impl<T, U, const LENGTH: usize> std::ops::Index<U> for SecureArray<T, LENGTH>
113where
114 T: Copy + Zeroize,
115 [T; LENGTH]: std::ops::Index<U>,
116{
117 type Output = <[T; LENGTH] as std::ops::Index<U>>::Output;
118
119 fn index(&self, index: U) -> &Self::Output {
120 std::ops::Index::index(&self.content, index)
121 }
122}
123
124impl<T, const LENGTH: usize> Borrow<[T]> for SecureArray<T, LENGTH>
126where
127 T: Copy + Zeroize,
128{
129 fn borrow(&self) -> &[T] {
130 self.content.borrow()
131 }
132}
133
134impl<T, const LENGTH: usize> BorrowMut<[T]> for SecureArray<T, LENGTH>
135where
136 T: Copy + Zeroize,
137{
138 fn borrow_mut(&mut self) -> &mut [T] {
139 self.content.borrow_mut()
140 }
141}
142
143impl<T, const LENGTH: usize> Drop for SecureArray<T, LENGTH>
145where
146 T: Copy + Zeroize,
147{
148 fn drop(&mut self) {
149 self.zero_out();
150 memlock::munlock(self.content.as_mut_ptr(), self.content.len());
151 }
152}
153
154impl<T, const LENGTH: usize> fmt::Debug for SecureArray<T, LENGTH>
156where
157 T: Copy + Zeroize,
158{
159 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
160 f.debug_struct("SecureArray").finish_non_exhaustive()
161 }
162}
163
164impl<T, const LENGTH: usize> fmt::Display for SecureArray<T, LENGTH>
165where
166 T: Copy + Zeroize,
167{
168 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
169 f.write_str("***SECRET***").map_err(|_| fmt::Error)
170 }
171}
172
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::SecureArray;

    #[test]
    fn test_basic() {
        let my_sec: SecureArray<_, 5> = SecureArray::from_str("hello").unwrap();
        assert_eq!(my_sec, SecureArray::from_str("hello").unwrap());
        assert_eq!(my_sec.unsecure(), b"hello");
    }

    #[test]
    fn test_zero_out() {
        let mut my_sec: SecureArray<_, 5> = SecureArray::from_str("hello").unwrap();
        my_sec.zero_out();
        assert_eq!(my_sec.unsecure(), b"\x00\x00\x00\x00\x00");
    }

    #[test]
    fn test_comparison() {
        assert_eq!(
            SecureArray::<_, 5>::from_str("hello").unwrap(),
            SecureArray::from_str("hello").unwrap()
        );
        assert_ne!(
            SecureArray::<_, 5>::from_str("hello").unwrap(),
            SecureArray::from_str("olleh").unwrap()
        );
    }

    #[test]
    fn test_indexing() {
        let string: SecureArray<_, 5> = SecureArray::from_str("hello").unwrap();
        assert_eq!(string[0], b'h');
        assert_eq!(&string[3..5], "lo".as_bytes());
    }

    #[test]
    fn test_show() {
        assert_eq!(
            format!("{:?}", SecureArray::<_, 5>::from_str("hello").unwrap()),
            "SecureArray { .. }".to_string()
        );
        assert_eq!(
            format!("{}", SecureArray::<_, 5>::from_str("hello").unwrap()),
            "***SECRET***".to_string()
        );
    }

    #[test]
    fn test_comparison_zero_out_multibyte() {
        let data1 = SecureArray::from([
            'H' as u32,
            'a' as u32,
            'l' as u32,
            'l' as u32,
            'o' as u32,
            ' ' as u32,
            '🦄' as u32,
            '!' as u32,
        ]);
        let data2 = SecureArray::from([
            'H' as u32,
            'a' as u32,
            'l' as u32,
            'l' as u32,
            'o' as u32,
            ' ' as u32,
            '🦄' as u32,
            '!' as u32,
        ]);
        let data3 = SecureArray::from([
            '!' as u32,
            '🦄' as u32,
            ' ' as u32,
            'o' as u32,
            'l' as u32,
            'l' as u32,
            'a' as u32,
            'H' as u32,
        ]);
        assert_eq!(data1, data2);
        assert_ne!(data1, data3);

        let mut zeroed = data1.clone();
        zeroed.zero_out();
        assert_eq!(zeroed.unsecure(), &[0u32; 8]);
    }

    // `TryFrom<Vec<T>>` must accept an exact-length vector and report a
    // length mismatch with a descriptive message.
    #[test]
    fn test_try_from_vec() {
        let sec = SecureArray::<u8, 3>::try_from(vec![1u8, 2, 3]).unwrap();
        assert_eq!(sec.unsecure(), &[1u8, 2, 3]);

        let err = SecureArray::<u8, 5>::try_from(vec![1u8, 2, 3]).unwrap_err();
        assert_eq!(err, "length mismatch: expected 5, but got 3");
    }

    #[test]
    fn test_from_str_length_mismatch() {
        assert!(SecureArray::<u8, 4>::from_str("hello").is_err());
        assert!(SecureArray::<u8, 6>::from_str("hello").is_err());
    }

    // Mutation through `unsecure_mut` must be visible, and a clone must not
    // share storage with its source.
    #[test]
    fn test_unsecure_mut_and_clone_independence() {
        let mut original = SecureArray::from([1u8, 2, 3]);
        original.unsecure_mut()[0] = 9;
        assert_eq!(original.unsecure(), &[9u8, 2, 3]);

        let mut copy = original.clone();
        copy.zero_out();
        assert_eq!(copy.unsecure(), &[0u8; 3]);
        assert_eq!(original.unsecure(), &[9u8, 2, 3]);
    }
}
263}