secure_types/
lib.rs

1#![doc = include_str!("../readme.md")]
2#![cfg_attr(feature = "no_os", no_os)]
3
4#[cfg(feature = "no_os")]
5extern crate alloc;
6
7pub mod array;
8pub mod string;
9pub mod vec;
10
11pub use array::SecureArray;
12pub use string::SecureString;
13pub use vec::{SecureBytes, SecureVec};
14
15use core::ptr::NonNull;
16pub use zeroize::Zeroize;
17
18#[cfg(feature = "use_os")]
19pub use memsec;
20#[cfg(feature = "use_os")]
21use memsec::Prot;
22#[cfg(feature = "use_os")]
23use std::sync::Once;
24
25use thiserror::Error as ThisError;
26
/// Errors produced by the OS-backed secure memory primitives (`use_os` feature).
#[cfg(feature = "use_os")]
#[derive(Debug, ThisError)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Error {
   /// The secure allocator could not provide the requested memory.
   #[error("Failed to allocate secure memory")]
   AllocationFailed,
   /// A zero length was supplied where a non-empty buffer is required.
   #[error("Length cannot be zero")]
   LengthCannotBeZero,
   /// The allocation produced a null pointer.
   #[error("Allocated Ptr is null")]
   NullAllocation,
   /// Windows `CryptProtectMemory` reported failure.
   #[error("CryptProtectMemory failed")]
   CryptProtectMemoryFailed,
   /// Windows `CryptUnprotectMemory` reported failure.
   #[error("CryptUnprotectMemory failed")]
   CryptUnprotectMemoryFailed,
   /// Locking the memory failed.
   #[error("Failed to lock memory")]
   LockFailed,
   /// Unlocking the memory failed.
   #[error("Failed to unlock memory")]
   UnlockFailed,
   /// The source length differs from the fixed size of the destination array.
   #[error("Source length does not match the fixed size of the destination array")]
   LengthMismatch,
}
48
49#[cfg(not(feature = "use_os"))]
50#[derive(Debug)]
51pub enum Error {
52   AlignmentFailed,
53   AllocationFailed,
54   NullAllocation,
55}
56
57#[cfg(all(feature = "use_os", test, windows))]
58use windows_sys::Win32::Foundation::GetLastError;
59#[cfg(all(feature = "use_os", windows))]
60use windows_sys::Win32::Security::Cryptography::{
61   CRYPTPROTECTMEMORY_BLOCK_SIZE, CRYPTPROTECTMEMORY_SAME_PROCESS, CryptProtectMemory,
62   CryptUnprotectMemory,
63};
64#[cfg(all(feature = "use_os", windows))]
65use windows_sys::Win32::System::SystemInformation::GetSystemInfo;
66
/// Guards the one-time initialization of `PAGE_SIZE` by `page_size_init`.
// Gated like `use std::sync::Once`, which is only imported with `use_os`.
#[cfg(feature = "use_os")]
static PAGE_SIZE_INIT: Once = Once::new();
/// Cached system page size in bytes; `0` until `PAGE_SIZE_INIT` has run.
#[cfg(feature = "use_os")]
static mut PAGE_SIZE: usize = 0;

/// Whether the kernel supports `memfd_secret`; set by `supports_memfd_secret_init`.
#[cfg(all(feature = "use_os", unix))]
static mut SUPPORTS_MEMFD_SECRET: bool = false;
/// Guards the one-time probe that sets `SUPPORTS_MEMFD_SECRET`.
#[cfg(all(feature = "use_os", unix))]
static SUPPORTS_MEMFD_SECRET_INIT: Once = Once::new();
74
#[cfg(feature = "use_os")]
/// Queries the operating system's memory page size and stores it in `PAGE_SIZE`.
///
/// On Unix the value comes from `sysconf(_SC_PAGESIZE)`. On Windows it comes from
/// the `dwPageSize` field filled in by `GetSystemInfo`.
///
/// # Safety
///
/// Writes the `static mut PAGE_SIZE` without synchronization. It must only run
/// inside `PAGE_SIZE_INIT.call_once`.
unsafe fn page_size_init() {
   #[cfg(unix)]
   {
      // SAFETY: `sysconf` is a plain libc query with no pointer arguments.
      let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as usize;
      // SAFETY: serialized by `PAGE_SIZE_INIT` (see `# Safety`).
      unsafe { PAGE_SIZE = page_size };
   }

   #[cfg(windows)]
   {
      let mut info = core::mem::MaybeUninit::uninit();
      // SAFETY: `GetSystemInfo` fills in the SYSTEM_INFO it is pointed at before we read it.
      let page_size = unsafe {
         GetSystemInfo(info.as_mut_ptr());
         info.assume_init().dwPageSize
      } as usize;
      // SAFETY: serialized by `PAGE_SIZE_INIT` (see `# Safety`).
      unsafe { PAGE_SIZE = page_size };
   }
}
92
#[cfg(feature = "use_os")]
/// Rounds `size` (in bytes) up to the next multiple of the system page size.
///
/// The page size is queried once through `PAGE_SIZE_INIT` and then cached.
/// A `size` of `0` yields `0`.
///
/// # Panics
///
/// Panics if the rounded-up size does not fit in `usize`. Wrapping around here
/// would turn a huge request into a tiny allocation, so the overflow is never
/// allowed to pass silently, in release builds as well as debug builds.
///
/// # Safety
///
/// Reads the `static mut PAGE_SIZE`. This is sound only because every write to it
/// happens inside `PAGE_SIZE_INIT.call_once`, which has completed before the read.
pub(crate) unsafe fn page_aligned_size(size: usize) -> usize {
   PAGE_SIZE_INIT.call_once(|| unsafe { page_size_init() });
   // SAFETY: the one-time initializer above has finished; no further writes happen.
   let page_size = unsafe { PAGE_SIZE };
   size
      .checked_next_multiple_of(page_size)
      .expect("size rounded up to the page size overflows usize")
}
99
/// Probes whether the running kernel implements the `memfd_secret` syscall and
/// records the outcome in `SUPPORTS_MEMFD_SECRET`.
///
/// The probe creates a secret memfd with flags `0` and immediately closes it. Any
/// failure of the syscall, not only `ENOSYS`, leaves the flag at its initial `false`.
///
/// # Safety
///
/// Writes the `static mut SUPPORTS_MEMFD_SECRET` without synchronization. Call it
/// only through `SUPPORTS_MEMFD_SECRET_INIT.call_once`.
#[cfg(all(feature = "use_os", unix))]
unsafe fn supports_memfd_secret_init() {
   use libc::{SYS_memfd_secret, close, syscall};

   // SAFETY: `memfd_secret` takes a single integer flags argument.
   let fd = unsafe { syscall(SYS_memfd_secret as _, 0isize) };
   if fd < 0 {
      // Unsupported (or otherwise failing): keep treating the feature as unavailable.
      return;
   }

   // SAFETY: `fd` is a descriptor we just obtained and own; it is not used afterwards.
   unsafe { close(fd as libc::c_int) };
   // SAFETY: only reached from within `call_once` (see `# Safety`).
   unsafe { SUPPORTS_MEMFD_SECRET = true };
}
121
122/// Allocate memory
123///
124/// Size is page aligned if the target is an OS
125///
126/// For `Windows` it always uses [memsec::malloc_sized]
127///
128/// For `Unix` it uses [memsec::memfd_secret_sized] if `memfd_secret` is supported
129///
130/// If the allocation fails it fallbacks to [memsec::malloc_sized]
131pub(crate) unsafe fn alloc_aligned<T>(size: usize) -> Result<NonNull<T>, Error> {
132   #[cfg(feature = "use_os")]
133   {
134      #[cfg(windows)]
135      unsafe {
136         let aligned_size = page_aligned_size(size);
137         let allocated_ptr = memsec::malloc_sized(aligned_size);
138         let non_null = allocated_ptr.ok_or(Error::AllocationFailed)?;
139         let ptr = non_null.as_ptr() as *mut T;
140         NonNull::new(ptr).ok_or(Error::NullAllocation)
141      }
142
143      #[cfg(unix)]
144      {
145         SUPPORTS_MEMFD_SECRET_INIT.call_once(|| unsafe { supports_memfd_secret_init() });
146
147         let aligned_size = unsafe { page_aligned_size(size) };
148         let supports_memfd_secret = unsafe { SUPPORTS_MEMFD_SECRET };
149
150         let ptr_opt = if supports_memfd_secret {
151            unsafe { memsec::memfd_secret_sized(aligned_size) }
152         } else {
153            None
154         };
155
156         if let Some(ptr) = ptr_opt {
157            NonNull::new(ptr.as_ptr() as *mut T).ok_or(Error::NullAllocation)
158         } else {
159            unsafe {
160               let allocated_ptr = memsec::malloc_sized(aligned_size);
161               let non_null = allocated_ptr.ok_or(Error::AllocationFailed)?;
162               let ptr = non_null.as_ptr() as *mut T;
163               NonNull::new(ptr).ok_or(Error::NullAllocation)
164            }
165         }
166      }
167   }
168
169   #[cfg(not(feature = "use_os"))]
170   {
171      let layout =
172         Layout::from_size_align(size, mem::align_of::<T>()).map_err(|_| Error::AlignmentFailed)?;
173      let ptr = unsafe { alloc::alloc(layout) as *mut T };
174      if ptr.is_null() {
175         return Err(Error::NullAllocation);
176      }
177      ptr
178   }
179}
180
#[cfg(feature = "use_os")]
/// Releases memory obtained from `alloc_aligned`.
///
/// On `Unix`, the deallocator is chosen from the process-wide `SUPPORTS_MEMFD_SECRET`
/// flag. It is [memsec::free_memfd_secret] when `memfd_secret` is supported and
/// [memsec::free] otherwise. On `Windows`, it is always [memsec::free].
///
/// NOTE(review): although this is a safe fn, `ptr` must come from `alloc_aligned` and
/// must not be freed twice. Nothing here checks either condition.
pub(crate) fn free<T>(ptr: NonNull<T>) {
   #[cfg(unix)]
   {
      // SAFETY: the flag is only written during its one-time initialization.
      let memfd_backed = unsafe { SUPPORTS_MEMFD_SECRET };
      if memfd_backed {
         unsafe { memsec::free_memfd_secret(ptr) };
         return;
      }
      unsafe { memsec::free(ptr) };
   }

   #[cfg(windows)]
   unsafe {
      memsec::free(ptr);
   }
}
198
#[cfg(feature = "use_os")]
/// Changes the protection of the memory at `ptr` to `prot` via [memsec::mprotect].
///
/// Returns whether the call succeeded. Test builds also print a message on failure.
pub(crate) fn mprotect<T>(ptr: NonNull<T>, prot: Prot::Ty) -> bool {
   // NOTE(review): safe fn wrapping an unsafe call; presumably `ptr` is always a
   // region obtained from `alloc_aligned`. Confirm at the call sites.
   let changed = unsafe { memsec::mprotect(ptr, prot) };
   if cfg!(test) && !changed {
      eprintln!("mprotect failed");
   }
   changed
}
210
#[cfg(all(feature = "use_os", windows))]
/// Encrypts `aligned_size` bytes at `ptr` in place with `CryptProtectMemory`,
/// using `CRYPTPROTECTMEMORY_SAME_PROCESS`.
///
/// Returns `true` when the region was encrypted or is empty (`aligned_size == 0`).
/// Returns `false` without calling the API when `aligned_size` is not a multiple of
/// `CRYPTPROTECTMEMORY_BLOCK_SIZE` or exceeds `u32::MAX`. Also returns `false` when
/// the API call itself fails; test builds then print the `GetLastError` code.
///
/// NOTE(review): this is a safe fn taking a raw pointer. Soundness relies on the
/// crate-internal callers passing `aligned_size` valid, writable bytes at `ptr`.
pub(crate) fn crypt_protect_memory(ptr: *mut u8, aligned_size: usize) -> bool {
   if aligned_size == 0 {
      return true; // Nothing to encrypt
   }

   let block_size = CRYPTPROTECTMEMORY_BLOCK_SIZE as usize;
   let byte_count = match u32::try_from(aligned_size) {
      Ok(count) if aligned_size % block_size == 0 => count,
      // Not a whole number of blocks, or too large for the API's `u32` length.
      _ => return false,
   };

   // SAFETY: see the NOTE in the doc comment; the length satisfies the API's limits.
   let status = unsafe {
      CryptProtectMemory(
         ptr.cast::<core::ffi::c_void>(),
         byte_count,
         CRYPTPROTECTMEMORY_SAME_PROCESS,
      )
   };
   let succeeded = status != 0;

   #[cfg(test)]
   {
      if !succeeded {
         let error_code = unsafe { GetLastError() };
         eprintln!(
            "CryptProtectMemory failed with error code: {}",
            error_code
         );
      }
   }

   succeeded
}
248
#[cfg(all(feature = "use_os", windows))]
/// Decrypts, in place, `size_in_bytes` bytes at `ptr` that were encrypted with
/// `CryptProtectMemory` (`CRYPTPROTECTMEMORY_SAME_PROCESS`).
///
/// Returns `true` on success or when `size_in_bytes == 0`. Returns `false` without
/// calling the API when the size is not a multiple of `CRYPTPROTECTMEMORY_BLOCK_SIZE`
/// or exceeds `u32::MAX`. Also returns `false` when `CryptUnprotectMemory` fails; test
/// builds then print the `GetLastError` code.
///
/// NOTE(review): safe fn taking a raw pointer. Soundness relies on the crate-internal
/// callers passing `size_in_bytes` valid, writable bytes at `ptr`.
pub(crate) fn crypt_unprotect_memory(ptr: *mut u8, size_in_bytes: usize) -> bool {
   if size_in_bytes == 0 {
      return true;
   }
   if size_in_bytes % (CRYPTPROTECTMEMORY_BLOCK_SIZE as usize) != 0 {
      return false;
   }
   let Ok(byte_count) = u32::try_from(size_in_bytes) else {
      return false;
   };

   // SAFETY: see the NOTE in the doc comment; the length satisfies the API's limits.
   let status = unsafe {
      CryptUnprotectMemory(
         ptr.cast::<core::ffi::c_void>(),
         byte_count,
         CRYPTPROTECTMEMORY_SAME_PROCESS,
      )
   };
   let succeeded = status != 0;

   #[cfg(test)]
   {
      if !succeeded {
         let error_code = unsafe { GetLastError() };
         eprintln!(
            "CryptUnprotectMemory failed with error code: {}",
            error_code
         );
      }
   }

   succeeded
}
285
#[cfg(test)]
mod tests {

   // Uses `use_os`-only items (`SUPPORTS_MEMFD_SECRET*`, `page_aligned_size`).
   #[cfg(all(unix, feature = "use_os"))]
   #[test]
   fn test_supports_memfd_secret() {
      use super::*;

      SUPPORTS_MEMFD_SECRET_INIT.call_once(|| unsafe { supports_memfd_secret_init() });

      let supports = unsafe { SUPPORTS_MEMFD_SECRET };

      if supports {
         println!("memfd_secret is supported");
         let size = size_of::<u8>();
         let aligned = unsafe { page_aligned_size(size) };
         let ptr = unsafe { memsec::memfd_secret_sized(aligned) }
            .expect("memfd_secret is supported but memfd_secret_sized failed");
         // Release the mapping so the test does not leak secret memory.
         unsafe { memsec::free_memfd_secret(ptr.cast::<u8>()) };
      } else {
         println!("memfd_secret is not supported");
      }
   }

   #[cfg(feature = "serde")]
   #[test]
   fn test_array_and_secure_vec_serde_compatibility() {
      use super::*;
      let exposed_array: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_array).unwrap();
      let vec: SecureVec<u8> = array.clone().into();

      let array_json_string = serde_json::to_string(&array).unwrap();
      let array_json_bytes = serde_json::to_vec(&array).unwrap();
      let vec_json_string = serde_json::to_string(&vec).unwrap();
      let vec_json_bytes = serde_json::to_vec(&vec).unwrap();

      assert_eq!(array_json_string, vec_json_string);
      assert_eq!(array_json_bytes, vec_json_bytes);

      let deserialized_array_from_string: SecureArray<u8, 3> =
         serde_json::from_str(&array_json_string).unwrap();

      let deserialized_array_from_bytes: SecureArray<u8, 3> =
         serde_json::from_slice(&array_json_bytes).unwrap();

      let deserialized_vec_from_string: SecureVec<u8> =
         serde_json::from_str(&vec_json_string).unwrap();

      let deserialized_vec_from_bytes: SecureVec<u8> =
         serde_json::from_slice(&vec_json_bytes).unwrap();

      deserialized_array_from_string.unlock(|slice| {
         deserialized_vec_from_string.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });

      deserialized_array_from_bytes.unlock(|slice| {
         deserialized_vec_from_bytes.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });
   }
}