securer_string/secure_types/
boxed.rs1use core::fmt;
2use std::borrow::{Borrow, BorrowMut};
3use std::mem::MaybeUninit;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8use crate::secure_utils::memlock;
9
/// Heap-allocated container for a `Copy` secret.
///
/// The value is passed to `memlock::mlock` on construction and its bytes are
/// zeroed before the allocation is released on drop.
pub struct SecureBox<T: Copy> {
    // `Some` for the whole lifetime of the value; only `Drop` takes it out so
    // the allocation can be dismantled by hand.
    content: Option<Box<T>>,
}
30
31impl<T> SecureBox<T>
32where
33 T: Copy,
34{
35 pub fn new(mut cont: Box<T>) -> Self {
36 memlock::mlock(&mut *cont, 1);
37 SecureBox {
38 content: Some(cont),
39 }
40 }
41
42 pub fn unsecure(&self) -> &T {
44 self.content
45 .as_deref()
46 .expect("SecureBox content accessed after drop")
47 }
48
49 pub fn unsecure_mut(&mut self) -> &mut T {
51 self.content
52 .as_deref_mut()
53 .expect("SecureBox content accessed after drop")
54 }
55}
56
57impl<T: Copy> Clone for SecureBox<T> {
58 fn clone(&self) -> Self {
59 Self::new(Box::new(*self.unsecure()))
60 }
61}
62
63impl<T: Copy + ConstantTimeEq> ConstantTimeEq for SecureBox<T> {
64 fn ct_eq(&self, other: &Self) -> subtle::Choice {
65 self.unsecure().ct_eq(other.unsecure())
66 }
67}
68
69impl<T: Copy + ConstantTimeEq> PartialEq for SecureBox<T> {
70 fn eq(&self, other: &Self) -> bool {
71 self.ct_eq(other).into()
72 }
73}
74
75impl<T: Copy + ConstantTimeEq> Eq for SecureBox<T> {}
76
77impl<T, U> std::ops::Index<U> for SecureBox<T>
79where
80 T: std::ops::Index<U> + Copy,
81{
82 type Output = <T as std::ops::Index<U>>::Output;
83
84 fn index(&self, index: U) -> &Self::Output {
85 std::ops::Index::index(self.unsecure(), index)
86 }
87}
88
89impl<T> Borrow<T> for SecureBox<T>
91where
92 T: Copy,
93{
94 fn borrow(&self) -> &T {
95 self.unsecure()
96 }
97}
98impl<T> BorrowMut<T> for SecureBox<T>
99where
100 T: Copy,
101{
102 fn borrow_mut(&mut self) -> &mut T {
103 self.unsecure_mut()
104 }
105}
106
107impl<T> Drop for SecureBox<T>
109where
110 T: Copy,
111{
112 fn drop(&mut self) {
113 let ptr = Box::into_raw(self.content.take().expect("SecureBox dropped twice"));
118
119 unsafe {
127 std::slice::from_raw_parts_mut::<MaybeUninit<u8>>(
128 ptr as *mut MaybeUninit<u8>,
129 std::mem::size_of::<T>(),
130 )
131 .zeroize();
132 }
133
134 memlock::munlock(ptr, 1);
135
136 if std::mem::size_of::<T>() != 0 {
138 unsafe { std::alloc::dealloc(ptr as *mut u8, std::alloc::Layout::new::<T>()) };
143 }
144 }
145}
146
147impl<T> fmt::Debug for SecureBox<T>
149where
150 T: Copy,
151{
152 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153 f.write_str("***SECRET***").map_err(|_| fmt::Error)
154 }
155}
156
157impl<T> fmt::Display for SecureBox<T>
158where
159 T: Copy,
160{
161 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
162 f.write_str("***SECRET***").map_err(|_| fmt::Error)
163 }
164}
165
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use zeroize::Zeroize;

    use super::SecureBox;
    use crate::test_utils::{PRIVATE_KEY_1, PRIVATE_KEY_2, Packed, Padded};

    /// Overwrites every byte of the value stored in `secure_box` with zero.
    ///
    /// # Safety
    ///
    /// The all-zero bit pattern must be a valid value of `T`. If it is not,
    /// any later read of the box (for example through `unsecure`) is
    /// undefined behaviour.
    unsafe fn zero_out_secure_box<T>(secure_box: &mut SecureBox<T>)
    where
        T: Copy,
    {
        let ptr = secure_box.unsecure_mut() as *mut T as *mut MaybeUninit<u8>;
        // SAFETY: `ptr` comes from a live `&mut T`, so it is valid for writes
        // of `size_of::<T>()` bytes, and `MaybeUninit<u8>` has alignment 1.
        // Validity of the resulting `T` is the caller's obligation.
        let bytes = unsafe { std::slice::from_raw_parts_mut(ptr, std::mem::size_of::<T>()) };
        bytes.zeroize();
    }

    #[test]
    fn test_secure_box() {
        let key_1 = SecureBox::new(Box::new(PRIVATE_KEY_1));
        let key_2 = SecureBox::new(Box::new(PRIVATE_KEY_2));
        let key_3 = SecureBox::new(Box::new(PRIVATE_KEY_1));
        assert!(key_1 == key_1);
        assert!(key_1 != key_2);
        assert!(key_2 != key_3);
        assert!(key_1 == key_3);

        let mut final_key = key_1.clone();
        // SAFETY: assumes the key type is plain bytes (its `.0` is compared
        // against `[0; 32]` below), for which all-zero is a valid value —
        // TODO confirm against `test_utils`.
        unsafe {
            zero_out_secure_box(&mut final_key);
        }
        assert_eq!(final_key.unsecure().0, [0; 32]);
    }

    #[test]
    fn test_repr_c_with_padding() {
        assert_eq!(std::mem::size_of::<Padded>(), 4);

        let sec_a = SecureBox::new(Box::new(Padded { x: 1, y: 2 }));
        let sec_b = SecureBox::new(Box::new(Padded { x: 1, y: 2 }));
        assert_eq!(sec_a, sec_b);

        let sec_c = SecureBox::new(Box::new(Padded { x: 1, y: 3 }));
        assert_ne!(sec_a, sec_c);

        let sec_d = SecureBox::new(Box::new(Padded { x: 2, y: 2 }));
        assert_ne!(sec_a, sec_d);
    }

    #[test]
    fn test_repr_c_packed() {
        assert_eq!(std::mem::size_of::<Packed>(), 3);

        let sec_a = SecureBox::new(Box::new(Packed { x: 42, y: 1000 }));
        let sec_b = SecureBox::new(Box::new(Packed { x: 42, y: 1000 }));
        let sec_c = SecureBox::new(Box::new(Packed { x: 42, y: 1001 }));
        let sec_d = SecureBox::new(Box::new(Packed { x: 43, y: 1000 }));

        assert_eq!(sec_a, sec_b);
        assert_ne!(sec_a, sec_c);
        assert_ne!(sec_a, sec_d);
    }
}