secure_types/
lib.rs

1#![doc = include_str!("../readme.md")]
2#![cfg_attr(feature = "no_os", no_os)]
3
4#[cfg(feature = "no_os")]
5extern crate alloc;
6
7pub mod array;
8pub mod string;
9pub mod vec;
10
11pub use array::SecureArray;
12pub use string::SecureString;
13pub use vec::{SecureBytes, SecureVec};
14
15use core::ptr::NonNull;
16pub use zeroize::Zeroize;
17
18#[cfg(feature = "use_os")]
19pub use memsec;
20#[cfg(feature = "use_os")]
21use memsec::Prot;
22#[cfg(all(feature = "use_os", unix))]
23use std::sync::Once;
24
25use thiserror::Error as ThisError;
26
/// Errors reported by this crate's secure-memory helpers when the `use_os` feature is enabled.
#[cfg(feature = "use_os")]
#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum Error {
   /// The underlying allocator failed, or the padded allocation size overflowed `usize`.
   #[error("Failed to allocate secure memory")]
   AllocationFailed,
   /// A length of zero was supplied where a non-empty length is required.
   #[error("Length cannot be zero")]
   LengthCannotBeZero,
   /// The allocation produced a null pointer.
   #[error("Allocated Ptr is null")]
   NullAllocation,
   /// Locking the memory region failed.
   #[error("Failed to lock memory")]
   LockFailed,
   /// Unlocking the memory region failed.
   #[error("Failed to unlock memory")]
   UnlockFailed,
   /// A source slice's length differs from the destination array's fixed size.
   #[error("Source length does not match the fixed size of the destination array")]
   LengthMismatch,
}
44
/// Errors reported by the allocation helpers when the `use_os` feature is disabled.
#[cfg(not(feature = "use_os"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
   /// The requested size and alignment do not form a valid `core::alloc::Layout`.
   AlignmentFailed,
   /// The allocation could not be performed.
   AllocationFailed,
   /// The allocator returned a null pointer.
   NullAllocation,
}

/// Human-readable messages, matching the `use_os` error messages where the variants overlap.
#[cfg(not(feature = "use_os"))]
impl core::fmt::Display for Error {
   fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
      let message = match self {
         Error::AlignmentFailed => "Invalid size or alignment for allocation layout",
         Error::AllocationFailed => "Failed to allocate secure memory",
         Error::NullAllocation => "Allocated Ptr is null",
      };
      f.write_str(message)
   }
}

/// Lets the error be used as `dyn std::error::Error` whenever `std` is available.
#[cfg(all(not(feature = "use_os"), not(feature = "no_os")))]
impl std::error::Error for Error {}
52
/// Header tag written by `alloc` in front of blocks obtained from `memsec::malloc_sized`.
#[cfg(all(feature = "use_os", unix))]
const ALLOC_TAG_MALLOC: usize = 0xDEAD_BEEF;
/// Header tag written by `alloc` in front of blocks obtained from `memsec::memfd_secret_sized`.
#[cfg(all(feature = "use_os", unix))]
const ALLOC_TAG_MEMFD: usize = 0x5EC0_0000;

/// One-time guard around `supports_memfd_secret_init`.
#[cfg(all(feature = "use_os", unix))]
static SUPPORTS_MEMFD_SECRET_INIT: Once = Once::new();
/// Result of the `memfd_secret` probe; set to `true` by `supports_memfd_secret_init` on success.
#[cfg(all(feature = "use_os", unix))]
static mut SUPPORTS_MEMFD_SECRET: bool = false;
61
/// Returns how many bytes `alloc` reserves in front of the user data for its `usize` tag.
///
/// The result is `size_of::<usize>()`, widened to `align_of::<T>()` when `T` needs a stricter
/// alignment, so it is always a multiple of `align_of::<T>()`.
#[cfg(all(feature = "use_os", unix))]
const fn get_header_offset<T>() -> usize {
   let tag_size = core::mem::size_of::<usize>();
   let type_align = core::mem::align_of::<T>();

   // The tag alone is enough unless `T` demands more alignment than the tag occupies.
   if type_align <= tag_size {
      return tag_size;
   }
   type_align
}
77
/// Probes whether the `memfd_secret` syscall is usable and records the answer in
/// `SUPPORTS_MEMFD_SECRET`.
///
/// On success the probe's file descriptor is closed immediately and the flag is set to `true`.
/// On any failure (typically `ENOSYS` when the kernel lacks the syscall) the flag is left
/// unchanged.
///
/// # Safety
///
/// This writes the `static mut SUPPORTS_MEMFD_SECRET` without synchronization. It must not race
/// with any other read or write of that flag. Within this file it is only invoked through
/// `SUPPORTS_MEMFD_SECRET_INIT.call_once`.
#[cfg(all(feature = "use_os", unix))]
unsafe fn supports_memfd_secret_init() {
   use libc::{SYS_memfd_secret, close, syscall};

   // SAFETY: memfd_secret takes a single `flags` argument; 0 requests the default behavior.
   let res = unsafe { syscall(SYS_memfd_secret as _, 0isize) };
   if res < 0 {
      // Not supported (or not available): keep the default `false`.
      return;
   }

   // SAFETY: `res` is a file descriptor this probe just created and exclusively owns.
   unsafe { close(res as libc::c_int) };
   // SAFETY: per this function's contract, the write is serialized by the caller (`Once`).
   unsafe { SUPPORTS_MEMFD_SECRET = true };
}
99
100/// Allocate memory
101///
102/// For `Windows` it always uses [memsec::malloc_sized]
103///
104/// For `Unix` it uses [memsec::memfd_secret_sized] if `memfd_secret` is supported
105///
106/// If the allocation fails it fallbacks to [memsec::malloc_sized]
107pub(crate) unsafe fn alloc<T>(size: usize) -> Result<NonNull<T>, Error> {
108   #[cfg(feature = "use_os")]
109   {
110      #[cfg(windows)]
111      unsafe {
112         let allocated_ptr = memsec::malloc_sized(size);
113         let non_null = allocated_ptr.ok_or(Error::AllocationFailed)?;
114         let ptr = non_null.as_ptr() as *mut T;
115         NonNull::new(ptr).ok_or(Error::NullAllocation)
116      }
117
118      #[cfg(unix)]
119      {
120         SUPPORTS_MEMFD_SECRET_INIT.call_once(|| unsafe { supports_memfd_secret_init() });
121         let supports_memfd_secret = unsafe { SUPPORTS_MEMFD_SECRET };
122
123         let header_offset = get_header_offset::<T>();
124
125         // Calculate alignment requirement
126         let align_req = core::mem::align_of::<usize>().max(core::mem::align_of::<T>());
127
128         // Calculate raw size (Header + Data)
129         let raw_size = size
130            .checked_add(header_offset)
131            .ok_or(Error::AllocationFailed)?;
132
133         // Calculate padded size to satisfy alignment
134         let remainder = raw_size % align_req;
135         let alloc_size = if remainder == 0 {
136            raw_size
137         } else {
138            raw_size
139               .checked_add(align_req - remainder)
140               .ok_or(Error::AllocationFailed)?
141         };
142
143         let ptr_opt = if supports_memfd_secret {
144            unsafe { memsec::memfd_secret_sized(alloc_size) }
145         } else {
146            None
147         };
148
149         if let Some(raw_ptr_nonnull) = ptr_opt {
150            let raw_ptr = raw_ptr_nonnull.as_ptr() as *mut u8;
151
152            // Write the MEMFD tag
153            unsafe { *(raw_ptr as *mut usize) = ALLOC_TAG_MEMFD };
154
155            let user_ptr = unsafe { raw_ptr.add(header_offset) as *mut T };
156            return NonNull::new(user_ptr).ok_or(Error::NullAllocation);
157         }
158
159         unsafe {
160            let allocated_ptr = memsec::malloc_sized(alloc_size);
161            let non_null = allocated_ptr.ok_or(Error::AllocationFailed)?;
162
163            let raw_ptr = non_null.as_ptr() as *mut u8;
164
165            // Write the MALLOC tag
166            *(raw_ptr as *mut usize) = ALLOC_TAG_MALLOC;
167
168            let user_ptr = raw_ptr.add(header_offset) as *mut T;
169            NonNull::new(user_ptr).ok_or(Error::NullAllocation)
170         }
171      }
172   }
173
174   #[cfg(not(feature = "use_os"))]
175   {
176      use core::alloc::Layout;
177      let layout =
178         Layout::from_size_align(size, mem::align_of::<T>()).map_err(|_| Error::AlignmentFailed)?;
179      let ptr = unsafe { alloc::alloc(layout) as *mut T };
180      if ptr.is_null() {
181         return Err(Error::NullAllocation);
182      }
183      unsafe { NonNull::new_unchecked(ptr) }
184   }
185}
186
/// Releases memory previously returned by `alloc::<T>`, using the same `T`.
///
/// On `Windows` the pointer is handed straight to [memsec::free].
///
/// On `Unix` the pointer is first moved back by `get_header_offset::<T>()` bytes to the start of
/// the block. There, `alloc` stored a tag naming the allocator that produced the block, and that
/// tag selects [memsec::free_memfd_secret] or [memsec::free]. An unrecognized tag panics in
/// debug builds. In release builds it is ignored, and the block is not freed.
#[cfg(feature = "use_os")]
pub(crate) fn free<T>(ptr: NonNull<T>) {
   #[cfg(windows)]
   unsafe {
      memsec::free(ptr);
   }

   #[cfg(unix)]
   unsafe {
      let block_start = ptr.as_ptr().cast::<u8>().sub(get_header_offset::<T>());
      let tag = block_start.cast::<usize>().read();
      let block = NonNull::new_unchecked(block_start);

      if tag == ALLOC_TAG_MEMFD {
         memsec::free_memfd_secret(block);
      } else if tag == ALLOC_TAG_MALLOC {
         memsec::free(block);
      } else {
         // Unknown tag: the header is corrupted or the pointer did not come from `alloc`.
         #[cfg(debug_assertions)]
         panic!("SecureAllocator: Corrupt header tag found: {:x}", tag);
      }
   }
}
228
/// Changes the protection of memory previously returned by `alloc::<T>` and returns the result
/// of [memsec::mprotect].
///
/// On `Windows` the pointer is passed unchanged. On `Unix` the whole block is passed, including
/// the tag header that `alloc` placed `get_header_offset::<T>()` bytes before `ptr`.
#[cfg(feature = "use_os")]
pub(crate) fn mprotect<T>(ptr: NonNull<T>, prot: Prot::Ty) -> bool {
   #[cfg(windows)]
   {
      unsafe { memsec::mprotect(ptr, prot) }
   }
   #[cfg(unix)]
   {
      let offset = get_header_offset::<T>();
      unsafe {
         let block_start = ptr.as_ptr().cast::<u8>().sub(offset).cast::<T>();
         memsec::mprotect(NonNull::new_unchecked(block_start), prot)
      }
   }
}
247
#[cfg(test)]
mod tests {

   /// Probes `memfd_secret` support. When the syscall is available, checks that a tiny secret
   /// allocation succeeds, then releases it.
   ///
   /// Gated exactly like the items it uses (`SUPPORTS_MEMFD_SECRET*`,
   /// `supports_memfd_secret_init`), so builds without `use_os` still compile.
   #[cfg(all(feature = "use_os", unix))]
   #[test]
   fn test_supports_memfd_secret() {
      use super::*;

      SUPPORTS_MEMFD_SECRET_INIT.call_once(|| unsafe { supports_memfd_secret_init() });

      // SAFETY: `call_once` has completed, so the flag is no longer being written.
      let supports = unsafe { SUPPORTS_MEMFD_SECRET };

      if supports {
         println!("memfd_secret is supported");
         let size = core::mem::size_of::<u8>();
         let ptr = unsafe { memsec::memfd_secret_sized(size) }
            .expect("memfd_secret_sized returned None although memfd_secret is supported");
         // Release the probe allocation instead of leaking the secret mapping.
         unsafe { memsec::free_memfd_secret(ptr) };
      } else {
         println!("memfd_secret is not supported");
      }
   }

   /// `SecureArray` and `SecureVec` must serialize to identical JSON and round-trip to equal
   /// contents.
   #[cfg(feature = "serde")]
   #[test]
   fn test_array_and_secure_vec_serde_compatibility() {
      use super::*;
      let exposed_array: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_array).unwrap();
      let vec: SecureVec<u8> = array.clone().into();

      let array_json_string = serde_json::to_string(&array).unwrap();
      let array_json_bytes = serde_json::to_vec(&array).unwrap();
      let vec_json_string = serde_json::to_string(&vec).unwrap();
      let vec_json_bytes = serde_json::to_vec(&vec).unwrap();

      assert_eq!(array_json_string, vec_json_string);
      assert_eq!(array_json_bytes, vec_json_bytes);

      let deserialized_array_from_string: SecureArray<u8, 3> =
         serde_json::from_str(&array_json_string).unwrap();

      let deserialized_array_from_bytes: SecureArray<u8, 3> =
         serde_json::from_slice(&array_json_bytes).unwrap();

      let deserialized_vec_from_string: SecureVec<u8> =
         serde_json::from_str(&vec_json_string).unwrap();

      let deserialized_vec_from_bytes: SecureVec<u8> =
         serde_json::from_slice(&vec_json_bytes).unwrap();

      deserialized_array_from_string.unlock(|slice| {
         deserialized_vec_from_string.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });

      deserialized_array_from_bytes.unlock(|slice| {
         deserialized_vec_from_bytes.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });
   }
}