// secure_string/secure_types/array.rs
use core::fmt;
use std::{
    borrow::{Borrow, BorrowMut},
    str::FromStr,
};

use zeroize::Zeroize;

use crate::secure_utils::memlock;
10
11#[derive(Eq, PartialEq, PartialOrd, Ord, Hash)]
21pub struct SecureArray<T, const LENGTH: usize>
22where
23 T: Copy + Zeroize,
24{
25 pub(crate) content: [T; LENGTH],
26}
27
28impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
29where
30 T: Copy + Zeroize,
31{
32 pub fn new(mut content: [T; LENGTH]) -> Self {
33 memlock::mlock(content.as_mut_ptr(), content.len());
34 Self { content }
35 }
36
37 pub fn unsecure(&self) -> &[T] {
39 self.borrow()
40 }
41
42 pub fn unsecure_mut(&mut self) -> &mut [T] {
44 self.borrow_mut()
45 }
46
47 pub fn zero_out(&mut self) {
49 self.content.zeroize()
50 }
51}
52
53impl<T: Copy + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
54 fn clone(&self) -> Self {
55 Self::new(self.content)
56 }
57}
58
59impl<T, const LENGTH: usize> From<[T; LENGTH]> for SecureArray<T, LENGTH>
61where
62 T: Copy + Zeroize,
63{
64 fn from(s: [T; LENGTH]) -> Self {
65 Self::new(s)
66 }
67}
68
69impl<T, const LENGTH: usize> TryFrom<Vec<T>> for SecureArray<T, LENGTH>
70where
71 T: Copy + Zeroize,
72{
73 type Error = String;
74
75 fn try_from(s: Vec<T>) -> Result<Self, Self::Error> {
76 Ok(Self::new(s.try_into().map_err(|error: Vec<T>| {
77 format!("length mismatch: expected {LENGTH}, but got {}", error.len())
78 })?))
79 }
80}
81
82impl<const LENGTH: usize> FromStr for SecureArray<u8, LENGTH> {
83 type Err = std::array::TryFromSliceError;
84
85 fn from_str(s: &str) -> Result<Self, Self::Err> {
86 Ok(SecureArray::new(s.as_bytes().try_into()?))
87 }
88}
89
90impl<T, U, const LENGTH: usize> std::ops::Index<U> for SecureArray<T, LENGTH>
92where
93 T: Copy + Zeroize,
94 [T; LENGTH]: std::ops::Index<U>,
95{
96 type Output = <[T; LENGTH] as std::ops::Index<U>>::Output;
97
98 fn index(&self, index: U) -> &Self::Output {
99 std::ops::Index::index(&self.content, index)
100 }
101}
102
103impl<T, const LENGTH: usize> Borrow<[T]> for SecureArray<T, LENGTH>
105where
106 T: Copy + Zeroize,
107{
108 fn borrow(&self) -> &[T] {
109 self.content.borrow()
110 }
111}
112
113impl<T, const LENGTH: usize> BorrowMut<[T]> for SecureArray<T, LENGTH>
114where
115 T: Copy + Zeroize,
116{
117 fn borrow_mut(&mut self) -> &mut [T] {
118 self.content.borrow_mut()
119 }
120}
121
122impl<T, const LENGTH: usize> Drop for SecureArray<T, LENGTH>
124where
125 T: Copy + Zeroize,
126{
127 fn drop(&mut self) {
128 self.zero_out();
129 memlock::munlock(self.content.as_mut_ptr(), self.content.len());
130 }
131}
132
133impl<T, const LENGTH: usize> fmt::Debug for SecureArray<T, LENGTH>
135where
136 T: Copy + Zeroize,
137{
138 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
139 f.write_str("***SECRET***").map_err(|_| fmt::Error)
140 }
141}
142
143impl<T, const LENGTH: usize> fmt::Display for SecureArray<T, LENGTH>
144where
145 T: Copy + Zeroize,
146{
147 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
148 f.write_str("***SECRET***").map_err(|_| fmt::Error)
149 }
150}
151
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::SecureArray;

    /// Parses `text` into a five-byte secret; panics when the byte length differs.
    fn secret5(text: &str) -> SecureArray<u8, 5> {
        SecureArray::from_str(text).unwrap()
    }

    /// Round-trips a string through `from_str` and reads it back.
    #[test]
    fn test_basic() {
        let secret = secret5("hello");
        assert_eq!(secret, secret5("hello"));
        assert_eq!(secret.unsecure(), b"hello");
    }

    /// After `zero_out` every byte reads as zero.
    #[test]
    #[cfg_attr(feature = "pre", pre::pre)]
    fn test_zero_out() {
        let mut secret = secret5("hello");
        secret.zero_out();
        assert_eq!(secret.unsecure(), b"\x00\x00\x00\x00\x00");
    }

    /// Equal contents compare equal; different contents do not.
    #[test]
    fn test_comparison() {
        assert_eq!(secret5("hello"), secret5("hello"));
        assert_ne!(secret5("hello"), secret5("olleh"));
    }

    /// Both element and range indexing reach the wrapped bytes.
    #[test]
    fn test_indexing() {
        let secret = secret5("hello");
        assert_eq!(secret[0], b'h');
        assert_eq!(&secret[3..5], "lo".as_bytes());
    }

    /// `Debug` and `Display` both print the redaction placeholder.
    #[test]
    fn test_show() {
        let debug_output = format!("{:?}", secret5("hello"));
        assert_eq!(debug_output, "***SECRET***".to_string());

        let display_output = format!("{}", secret5("hello"));
        assert_eq!(display_output, "***SECRET***".to_string());
    }

    /// Comparison and wiping also work for multi-byte `char` elements.
    #[test]
    #[cfg_attr(feature = "pre", pre::pre)]
    fn test_comparison_zero_out_mb() {
        let forward = ['H', 'a', 'l', 'l', 'o', ' ', '🦄', '!'];
        let backward = ['!', '🦄', ' ', 'o', 'l', 'l', 'a', 'H'];

        let first = SecureArray::from(forward);
        let second = SecureArray::from(forward);
        let reversed = SecureArray::from(backward);
        assert!(first == second);
        assert!(first != reversed);

        let mut wiped = first.clone();
        wiped.zero_out();
        assert_eq!(wiped.unsecure(), &['\0'; 8]);
    }
}