securer_string/secure_types/
array.rs1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::str::FromStr;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::SecureBox;
9
10pub struct SecureArray<T, const LENGTH: usize>
27where
28 T: Copy + Zeroize,
29{
30 inner: SecureBox<[T; LENGTH]>,
31}
32
33impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
34where
35 T: Copy + Zeroize,
36{
37 #[must_use]
38 pub fn new(content: [T; LENGTH]) -> Self {
39 Self {
40 inner: SecureBox::new(Box::new(content)),
41 }
42 }
43
44 #[must_use]
46 pub fn unsecure(&self) -> &[T] {
47 self.borrow()
48 }
49
50 #[must_use]
52 pub fn unsecure_mut(&mut self) -> &mut [T] {
53 self.borrow_mut()
54 }
55
56 pub fn zero_out(&mut self) {
59 self.inner.unsecure_mut().zeroize();
60 }
61}
62
63impl<T: Copy + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
64 fn clone(&self) -> Self {
65 Self {
66 inner: self.inner.clone(),
67 }
68 }
69}
70
71impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> ConstantTimeEq
72 for SecureArray<T, LENGTH>
73{
74 fn ct_eq(&self, other: &Self) -> subtle::Choice {
75 self.unsecure().ct_eq(other.unsecure())
76 }
77}
78
79impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> PartialEq for SecureArray<T, LENGTH> {
80 fn eq(&self, other: &Self) -> bool {
81 self.ct_eq(other).into()
82 }
83}
84
85impl<T: Copy + Zeroize + ConstantTimeEq, const LENGTH: usize> Eq for SecureArray<T, LENGTH> {}
86
87impl<T, const LENGTH: usize> From<[T; LENGTH]> for SecureArray<T, LENGTH>
89where
90 T: Copy + Zeroize,
91{
92 fn from(s: [T; LENGTH]) -> Self {
93 Self::new(s)
94 }
95}
96
97impl<T, const LENGTH: usize> TryFrom<Vec<T>> for SecureArray<T, LENGTH>
98where
99 T: Copy + Zeroize,
100{
101 type Error = String;
102
103 fn try_from(s: Vec<T>) -> Result<Self, Self::Error> {
104 Ok(Self::new(s.try_into().map_err(|error: Vec<T>| {
105 format!(
106 "length mismatch: expected {LENGTH}, but got {}",
107 error.len()
108 )
109 })?))
110 }
111}
112
113impl<const LENGTH: usize> FromStr for SecureArray<u8, LENGTH> {
114 type Err = std::array::TryFromSliceError;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
117 Ok(SecureArray::new(s.as_bytes().try_into()?))
118 }
119}
120
121impl<T, U, const LENGTH: usize> std::ops::Index<U> for SecureArray<T, LENGTH>
123where
124 T: Copy + Zeroize,
125 [T; LENGTH]: std::ops::Index<U>,
126{
127 type Output = <[T; LENGTH] as std::ops::Index<U>>::Output;
128
129 fn index(&self, index: U) -> &Self::Output {
130 std::ops::Index::index(self.inner.unsecure(), index)
131 }
132}
133
134impl<T, const LENGTH: usize> Borrow<[T]> for SecureArray<T, LENGTH>
136where
137 T: Copy + Zeroize,
138{
139 fn borrow(&self) -> &[T] {
140 self.inner.unsecure().as_slice()
141 }
142}
143
144impl<T, const LENGTH: usize> BorrowMut<[T]> for SecureArray<T, LENGTH>
145where
146 T: Copy + Zeroize,
147{
148 fn borrow_mut(&mut self) -> &mut [T] {
149 self.inner.unsecure_mut().as_mut_slice()
150 }
151}
152
153impl<T, const LENGTH: usize> fmt::Debug for SecureArray<T, LENGTH>
155where
156 T: Copy + Zeroize,
157{
158 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159 f.debug_struct("SecureArray").finish_non_exhaustive()
160 }
161}
162
163impl<T, const LENGTH: usize> fmt::Display for SecureArray<T, LENGTH>
164where
165 T: Copy + Zeroize,
166{
167 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
168 f.write_str("***SECRET***").map_err(|_| fmt::Error)
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::SecureArray;

    #[test]
    fn test_basic() {
        let my_sec: SecureArray<_, 5> = SecureArray::from_str("hello").unwrap();
        assert_eq!(my_sec, SecureArray::from_str("hello").unwrap());
        assert_eq!(my_sec.unsecure(), b"hello");
    }

    #[test]
    fn test_zero_out() {
        let mut my_sec: SecureArray<_, 5> = SecureArray::from_str("hello").unwrap();
        my_sec.zero_out();
        assert_eq!(my_sec.unsecure(), b"\x00\x00\x00\x00\x00");
    }

    #[test]
    fn test_comparison() {
        assert_eq!(
            SecureArray::<_, 5>::from_str("hello").unwrap(),
            SecureArray::from_str("hello").unwrap()
        );
        assert_ne!(
            SecureArray::<_, 5>::from_str("hello").unwrap(),
            SecureArray::from_str("olleh").unwrap()
        );
    }

    #[test]
    fn test_indexing() {
        let string: SecureArray<_, 5> = SecureArray::from_str("hello").unwrap();
        assert_eq!(string[0], b'h');
        assert_eq!(&string[3..5], "lo".as_bytes());
    }

    #[test]
    fn test_show() {
        assert_eq!(
            format!("{:?}", SecureArray::<_, 5>::from_str("hello").unwrap()),
            "SecureArray { .. }".to_string()
        );
        assert_eq!(
            format!("{}", SecureArray::<_, 5>::from_str("hello").unwrap()),
            "***SECRET***".to_string()
        );
    }

    #[test]
    fn test_move_keeps_contents() {
        fn make() -> SecureArray<u8, 5> {
            SecureArray::from_str("hello").unwrap()
        }
        let moved = make();
        let v = [moved];
        assert_eq!(v[0].unsecure(), b"hello");
    }

    #[test]
    fn test_comparison_zero_out_multibyte() {
        let data1 = SecureArray::from(['H', 'a', 'l', 'l', 'o', ' ', '🦄', '!'].map(u32::from));
        let data2 = SecureArray::from(['H', 'a', 'l', 'l', 'o', ' ', '🦄', '!'].map(u32::from));
        let data3 = SecureArray::from(['!', '🦄', ' ', 'o', 'l', 'l', 'a', 'H'].map(u32::from));
        assert_eq!(data1, data2);
        assert_ne!(data1, data3);

        let mut zeroed = data1.clone();
        zeroed.zero_out();
        assert_eq!(zeroed.unsecure(), &[0u32; 8]);
    }

    #[test]
    fn test_try_from_vec() {
        let exact = SecureArray::<u8, 5>::try_from(b"hello".to_vec()).unwrap();
        assert_eq!(exact.unsecure(), b"hello");

        let too_short = SecureArray::<u8, 5>::try_from(b"hi".to_vec()).unwrap_err();
        assert_eq!(too_short, "length mismatch: expected 5, but got 2");

        let too_long = SecureArray::<u8, 5>::try_from(b"hello!".to_vec()).unwrap_err();
        assert_eq!(too_long, "length mismatch: expected 5, but got 6");
    }

    #[test]
    fn test_from_str_length_mismatch() {
        assert!(SecureArray::<u8, 5>::from_str("hi").is_err());
        assert!(SecureArray::<u8, 5>::from_str("hello!").is_err());
    }
}
274}