securer_string/secure_types/
boxed.rs1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::mem::MaybeUninit;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::secure_utils::memlock;
9
/// A heap-allocated `Copy` value whose allocation is handed to `memlock::mlock`
/// on construction and wiped byte-by-byte when the box is dropped.
///
/// The value is only reachable through explicit accessors; the `Debug` and
/// `Display` implementations never print it.
pub struct SecureBox<T: Copy> {
    /// The boxed secret. It stays `Some` for the whole life of the box; the
    /// `Drop` implementation takes it out so it can wipe and free the memory.
    content: Option<Box<T>>,
}
30
31impl<T> SecureBox<T>
32where
33 T: Copy,
34{
35 #[must_use]
36 pub fn new(mut cont: Box<T>) -> Self {
37 memlock::mlock(&raw mut *cont, 1);
38 SecureBox {
39 content: Some(cont),
40 }
41 }
42
43 #[must_use]
49 pub fn unsecure(&self) -> &T {
50 self.content
51 .as_deref()
52 .expect("SecureBox content accessed after drop")
53 }
54
55 pub fn unsecure_mut(&mut self) -> &mut T {
61 self.content
62 .as_deref_mut()
63 .expect("SecureBox content accessed after drop")
64 }
65}
66
67impl<T: Copy> Clone for SecureBox<T> {
68 fn clone(&self) -> Self {
69 Self::new(Box::new(*self.unsecure()))
70 }
71}
72
73impl<T: Copy + ConstantTimeEq> ConstantTimeEq for SecureBox<T> {
74 fn ct_eq(&self, other: &Self) -> subtle::Choice {
75 self.unsecure().ct_eq(other.unsecure())
76 }
77}
78
79impl<T: Copy + ConstantTimeEq> PartialEq for SecureBox<T> {
80 fn eq(&self, other: &Self) -> bool {
81 self.ct_eq(other).into()
82 }
83}
84
85impl<T: Copy + ConstantTimeEq> Eq for SecureBox<T> {}
86
87impl<T, U> std::ops::Index<U> for SecureBox<T>
89where
90 T: std::ops::Index<U> + Copy,
91{
92 type Output = <T as std::ops::Index<U>>::Output;
93
94 fn index(&self, index: U) -> &Self::Output {
95 std::ops::Index::index(self.unsecure(), index)
96 }
97}
98
99impl<T> Borrow<T> for SecureBox<T>
101where
102 T: Copy,
103{
104 fn borrow(&self) -> &T {
105 self.unsecure()
106 }
107}
108impl<T> BorrowMut<T> for SecureBox<T>
109where
110 T: Copy,
111{
112 fn borrow_mut(&mut self) -> &mut T {
113 self.unsecure_mut()
114 }
115}
116
impl<T> Drop for SecureBox<T>
where
    T: Copy,
{
    /// Wipes and releases the protected allocation.
    ///
    /// The steps run in this order:
    /// 1. The bytes are zeroed.
    /// 2. `memlock::munlock` is called.
    /// 3. The memory is returned to the global allocator.
    fn drop(&mut self) {
        // Take the allocation out as a raw pointer so it can be wiped and freed
        // by hand instead of through `Box`'s own drop.
        let ptr = Box::into_raw(self.content.take().expect("SecureBox dropped twice"));

        // SAFETY: `ptr` comes from `Box::into_raw`, so it is non-null, aligned and
        // valid for writes of `size_of::<T>()` bytes. For zero-sized `T` it is
        // dangling but aligned, and the slice length is 0. Viewing the bytes as
        // `MaybeUninit<u8>` is sound even if `T` contains padding.
        // NOTE(review): relies on the zeroize crate to keep this wipe from being
        // optimised away — confirm against the zeroize version in use.
        unsafe {
            std::slice::from_raw_parts_mut::<MaybeUninit<u8>>(
                ptr.cast::<MaybeUninit<u8>>(),
                std::mem::size_of::<T>(),
            )
            .zeroize();
        }

        // Presumably undoes the `memlock::mlock` call made in `new`.
        // NOTE(review): if this wraps POSIX munlock, page locks do not nest.
        // Unlocking this box may therefore also unlock a page it shares with
        // another live SecureBox — confirm.
        memlock::munlock(ptr, 1);

        // Zero-sized `T` has no real heap allocation to give back.
        if std::mem::size_of::<T>() != 0 {
            // SAFETY: `ptr` was allocated by the global allocator for a `Box<T>`,
            // i.e. with `Layout::new::<T>()`, and has not been freed yet.
            // Freeing it directly means no `Box<T>` is rebuilt around the zeroed
            // bytes, which may not form a valid `T`.
            unsafe { std::alloc::dealloc(ptr.cast::<u8>(), std::alloc::Layout::new::<T>()) };
        }
    }
}
156
157impl<T> fmt::Debug for SecureBox<T>
159where
160 T: Copy,
161{
162 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
163 f.debug_struct("SecureBox").finish_non_exhaustive()
164 }
165}
166
167impl<T> fmt::Display for SecureBox<T>
168where
169 T: Copy,
170{
171 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
172 f.write_str("***SECRET***").map_err(|_| fmt::Error)
173 }
174}
175
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use zeroize::Zeroize;

    use super::SecureBox;
    use crate::test_utils::{PRIVATE_KEY_1, PRIVATE_KEY_2, Packed, Padded};

    /// Overwrites every byte of the value held by `secure_box` with zero.
    ///
    /// # Safety
    ///
    /// Afterwards the box holds an all-zero bit pattern. `T` must therefore be a
    /// type for which all-zero bytes are a valid value, such as integers or
    /// arrays of them. Otherwise any later use of the value is undefined
    /// behaviour.
    unsafe fn zero_out_secure_box<T>(secure_box: &mut SecureBox<T>)
    where
        T: Copy,
    {
        // SAFETY: `unsecure_mut` yields a unique, aligned reference to one `T`, so
        // the pointer is valid for writes of `size_of::<T>()` bytes, and
        // `MaybeUninit<u8>` tolerates padding bytes. The caller is responsible
        // for the zeroed result being a valid `T` (see `# Safety`).
        unsafe {
            std::slice::from_raw_parts_mut::<MaybeUninit<u8>>(
                std::ptr::from_mut::<T>(secure_box.unsecure_mut()).cast::<MaybeUninit<u8>>(),
                std::mem::size_of::<T>(),
            )
            .zeroize();
        }
    }

    #[test]
    fn test_secure_box() {
        let key_1 = SecureBox::new(Box::new(PRIVATE_KEY_1));
        let key_2 = SecureBox::new(Box::new(PRIVATE_KEY_2));
        let key_3 = SecureBox::new(Box::new(PRIVATE_KEY_1));
        assert_eq!(key_1, key_1);
        assert_ne!(key_1, key_2);
        assert_ne!(key_2, key_3);
        assert_eq!(key_1, key_3);

        let mut final_key = key_1.clone();
        // SAFETY: the key's `.0` is compared with `[0; 32]` below, so it is
        // presumably a plain 32-byte array wrapper for which all-zero bytes are
        // valid — confirm in `test_utils`.
        unsafe {
            zero_out_secure_box(&mut final_key);
        }
        assert_eq!(final_key.unsecure().0, [0; 32]);
        // The clone must own a separate allocation: wiping it leaves the
        // original untouched.
        assert_eq!(key_1, key_3);
    }

    #[test]
    fn test_formatting_hides_secret() {
        let key = SecureBox::new(Box::new(PRIVATE_KEY_1));
        assert_eq!(format!("{key}"), "***SECRET***");
        assert_eq!(format!("{key:?}"), "SecureBox { .. }");
    }

    #[test]
    fn test_repr_c_with_padding() {
        assert_eq!(std::mem::size_of::<Padded>(), 4);

        let sec_a = SecureBox::new(Box::new(Padded { x: 1, y: 2 }));
        let sec_b = SecureBox::new(Box::new(Padded { x: 1, y: 2 }));
        assert_eq!(sec_a, sec_b);

        let sec_c = SecureBox::new(Box::new(Padded { x: 1, y: 3 }));
        assert_ne!(sec_a, sec_c);

        let sec_d = SecureBox::new(Box::new(Padded { x: 2, y: 2 }));
        assert_ne!(sec_a, sec_d);
    }

    #[test]
    fn test_repr_c_packed() {
        assert_eq!(std::mem::size_of::<Packed>(), 3);

        let sec_a = SecureBox::new(Box::new(Packed { x: 42, y: 1000 }));
        let sec_b = SecureBox::new(Box::new(Packed { x: 42, y: 1000 }));
        let sec_c = SecureBox::new(Box::new(Packed { x: 42, y: 1001 }));
        let sec_d = SecureBox::new(Box::new(Packed { x: 43, y: 1000 }));

        assert_eq!(sec_a, sec_b);
        assert_ne!(sec_a, sec_c);
        assert_ne!(sec_a, sec_d);
    }
}