securer_string/secure_types/
vec.rs1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::str::FromStr;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::secure_utils::memlock;
9
10pub struct SecureVec<T>
30where
31 T: Copy + Zeroize,
32{
33 pub(crate) content: Vec<T>,
34}
35
/// Convenience alias for a [`SecureVec`] that stores raw bytes.
pub type SecureBytes = SecureVec<u8>;
38
39impl<T> SecureVec<T>
40where
41 T: Copy + Zeroize,
42{
43 pub fn new(mut cont: Vec<T>) -> Self {
44 memlock::mlock(cont.as_mut_ptr(), cont.capacity());
45 SecureVec { content: cont }
46 }
47
48 pub fn unsecure(&self) -> &[T] {
50 self.borrow()
51 }
52
53 pub fn unsecure_mut(&mut self) -> &mut [T] {
55 self.borrow_mut()
56 }
57
58 pub fn resize(&mut self, new_len: usize, value: T) {
69 if new_len <= self.content.len() {
71 self.content.truncate(new_len);
72 return;
73 }
74
75 let mut new_vec = vec![value; new_len];
77 memlock::mlock(new_vec.as_mut_ptr(), new_vec.capacity());
78 new_vec[0..self.content.len()].copy_from_slice(&self.content);
79
80 self.zero_out();
82 memlock::munlock(self.content.as_mut_ptr(), self.content.capacity());
83 self.content = new_vec;
84 }
85
86 pub fn zero_out(&mut self) {
91 self.content.zeroize()
92 }
93}
94
95impl<T: Copy + Zeroize> Clone for SecureVec<T> {
96 fn clone(&self) -> Self {
97 Self::new(self.content.clone())
98 }
99}
100
101impl<T: Copy + Zeroize + ConstantTimeEq> ConstantTimeEq for SecureVec<T> {
102 fn ct_eq(&self, other: &Self) -> subtle::Choice {
103 self.content.ct_eq(&other.content)
104 }
105}
106
107impl<T: Copy + Zeroize + ConstantTimeEq> PartialEq for SecureVec<T> {
108 fn eq(&self, other: &Self) -> bool {
109 self.ct_eq(other).into()
110 }
111}
112
113impl<T: Copy + Zeroize + ConstantTimeEq> Eq for SecureVec<T> {}
114
115impl<T, U> From<U> for SecureVec<T>
117where
118 U: Into<Vec<T>>,
119 T: Copy + Zeroize,
120{
121 fn from(s: U) -> SecureVec<T> {
122 SecureVec::new(s.into())
123 }
124}
125
126impl FromStr for SecureVec<u8> {
127 type Err = std::convert::Infallible;
128
129 fn from_str(s: &str) -> Result<Self, Self::Err> {
130 Ok(SecureVec::new(s.into()))
131 }
132}
133
134impl<T, U> std::ops::Index<U> for SecureVec<T>
136where
137 T: Copy + Zeroize,
138 Vec<T>: std::ops::Index<U>,
139{
140 type Output = <Vec<T> as std::ops::Index<U>>::Output;
141
142 fn index(&self, index: U) -> &Self::Output {
143 std::ops::Index::index(&self.content, index)
144 }
145}
146
147impl<T> Borrow<[T]> for SecureVec<T>
149where
150 T: Copy + Zeroize,
151{
152 fn borrow(&self) -> &[T] {
153 self.content.borrow()
154 }
155}
156
157impl<T> BorrowMut<[T]> for SecureVec<T>
158where
159 T: Copy + Zeroize,
160{
161 fn borrow_mut(&mut self) -> &mut [T] {
162 self.content.borrow_mut()
163 }
164}
165
166impl<T> Drop for SecureVec<T>
168where
169 T: Copy + Zeroize,
170{
171 fn drop(&mut self) {
172 self.zero_out();
173 memlock::munlock(self.content.as_mut_ptr(), self.content.capacity());
174 }
175}
176
177impl<T> fmt::Debug for SecureVec<T>
179where
180 T: Copy + Zeroize,
181{
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 f.write_str("***SECRET***").map_err(|_| fmt::Error)
184 }
185}
186
187impl<T> fmt::Display for SecureVec<T>
188where
189 T: Copy + Zeroize,
190{
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 f.write_str("***SECRET***").map_err(|_| fmt::Error)
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::{SecureBytes, SecureVec};

    /// Secret text shared by the byte-oriented tests.
    const HELLO: &str = "hello";

    /// Encodes `text` as one `u32` per Unicode scalar value.
    fn code_points(text: &str) -> Vec<u32> {
        text.chars().map(u32::from).collect()
    }

    /// Construction from `&str` and `String` yields equal values whose raw
    /// bytes match the input.
    #[test]
    fn test_basic() {
        let from_str = SecureBytes::from(HELLO);
        let from_string = SecureBytes::from(HELLO.to_string());
        assert_eq!(from_str, from_string);
        assert_eq!(from_str.unsecure(), b"hello");
    }

    /// After `zero_out` the five bytes that held the secret are all zero.
    #[test]
    fn test_zero_out() {
        let mut wiped = SecureBytes::from(HELLO);
        wiped.zero_out();
        // SAFETY (assumed — confirm): `zero_out` is expected to leave the
        // vector empty but keep its allocation of at least 5 bytes, all of
        // which were initialised and have just been overwritten with zeros,
        // so restoring the length only re-exposes initialised memory.
        unsafe {
            wiped.content.set_len(5);
        }
        assert_eq!(wiped.unsecure(), b"\x00\x00\x00\x00\x00");
    }

    /// Shrinking truncates; growing keeps the prefix and pads with `value`.
    #[test]
    fn test_resize() {
        let mut numbers = SecureVec::from([0, 1]);
        assert_eq!(numbers.unsecure().len(), 2);

        numbers.resize(1, 0);
        assert_eq!(numbers.unsecure().len(), 1);

        numbers.resize(16, 2);
        assert_eq!(
            numbers.unsecure(),
            &[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
        );
    }

    /// Equality holds only for identical contents of identical length.
    #[test]
    fn test_comparison() {
        assert_eq!(SecureBytes::from(HELLO), SecureBytes::from(HELLO));
        assert_ne!(SecureBytes::from(HELLO), SecureBytes::from("yolo"));
        assert_ne!(SecureBytes::from(HELLO), SecureBytes::from("olleh"));
        assert_ne!(SecureBytes::from(HELLO), SecureBytes::from("helloworld"));
        assert_ne!(SecureBytes::from(HELLO), SecureBytes::from(""));
    }

    /// Both single-element and range indexing are forwarded to the vector.
    #[test]
    fn test_indexing() {
        let secret = SecureBytes::from(HELLO);
        assert_eq!(secret[0], b'h');
        assert_eq!(&secret[3..5], &b"lo"[..]);
    }

    /// `Debug` and `Display` both print the placeholder, never the secret.
    #[test]
    fn test_show() {
        let debug_output = format!("{:?}", SecureBytes::from(HELLO));
        assert_eq!(debug_output, "***SECRET***");

        let display_output = format!("{}", SecureBytes::from(HELLO));
        assert_eq!(display_output, "***SECRET***");
    }

    /// Same checks as above, but with multi-byte (`u32`) elements.
    #[test]
    fn test_comparison_zero_out_mb() {
        let mbstring1 = SecureVec::from(code_points("Hallo 🦄!"));
        let mbstring2 = SecureVec::from(code_points("Hallo 🦄!"));
        let mbstring3 = SecureVec::from(code_points("!🦄 ollaH"));
        assert!(mbstring1 == mbstring2);
        assert!(mbstring1 != mbstring3);

        let mut mbstring = mbstring1.clone();
        mbstring.zero_out();
        // SAFETY (assumed — confirm): as in `test_zero_out`, the allocation
        // is expected to survive `zero_out` and to hold at least the 8
        // previously initialised, now zeroed, elements.
        unsafe {
            mbstring.content.set_len(8);
        }
        assert_eq!(mbstring.unsecure(), &[0u32; 8]);
    }
}