secure_types/
lib.rs

1#![doc = include_str!("../readme.md")]
2#![cfg_attr(feature = "no_os", no_os)]
3
4#[cfg(feature = "no_os")]
5extern crate alloc;
6
7pub mod array;
8pub mod string;
9pub mod vec;
10
11pub use array::SecureArray;
12pub use string::SecureString;
13pub use vec::{SecureBytes, SecureVec};
14
15use core::ptr::NonNull;
16pub use zeroize::Zeroize;
17
18#[cfg(feature = "use_os")]
19pub use memsec;
20#[cfg(feature = "use_os")]
21use memsec::Prot;
22#[cfg(all(feature = "use_os", unix))]
23use std::sync::Once;
24
25use thiserror::Error as ThisError;
26
/// Errors produced by the secure memory helpers in `use_os` builds.
///
/// All variants are fieldless, so the type is cheap to clone and compare.
/// The variant order is kept stable because it is visible to index-based serde formats.
#[cfg(feature = "use_os")]
#[derive(ThisError, Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum Error {
   /// The secure allocator could not provide memory.
   #[error("Failed to allocate secure memory")]
   AllocationFailed,
   /// A zero length was supplied where a non-empty buffer is required.
   #[error("Length cannot be zero")]
   LengthCannotBeZero,
   /// An allocation produced a null pointer.
   #[error("Allocated Ptr is null")]
   NullAllocation,
   /// The memory could not be locked.
   #[error("Failed to lock memory")]
   LockFailed,
   /// The memory could not be unlocked.
   #[error("Failed to unlock memory")]
   UnlockFailed,
   /// A source's length differs from the fixed size of the destination array.
   #[error("Source length does not match the fixed size of the destination array")]
   LengthMismatch,
}
44
/// Errors produced by the allocation helpers when the `use_os` feature is disabled.
///
/// Mirrors the messages of the `use_os` error type so both configurations report
/// failures the same way.
#[cfg(not(feature = "use_os"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
   /// The requested size and alignment do not form a valid allocation layout.
   AlignmentFailed,
   /// Memory could not be allocated.
   AllocationFailed,
   /// The allocator returned a null pointer.
   NullAllocation,
}

#[cfg(not(feature = "use_os"))]
impl core::fmt::Display for Error {
   fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
      let message = match self {
         Error::AlignmentFailed => "Invalid size or alignment for the allocation layout",
         Error::AllocationFailed => "Failed to allocate secure memory",
         Error::NullAllocation => "Allocated Ptr is null",
      };
      f.write_str(message)
   }
}
52
53
/// One-shot guard around the `memfd_secret` support probe
/// (`supports_memfd_secret_init`).
#[cfg(all(feature = "use_os", unix))]
static SUPPORTS_MEMFD_SECRET_INIT: Once = Once::new();

/// Outcome of the `memfd_secret` support probe; starts out `false`.
///
/// Written only by `supports_memfd_secret_init`, which is invoked through
/// `SUPPORTS_MEMFD_SECRET_INIT.call_once`.
#[cfg(all(feature = "use_os", unix))]
static mut SUPPORTS_MEMFD_SECRET: bool = false;
58
59
/// Probe whether the running kernel supports `memfd_secret(2)` and record the result
/// in `SUPPORTS_MEMFD_SECRET`.
///
/// The probe creates a secret memfd with no flags. A non-negative result means the
/// syscall is available, so the descriptor is closed immediately and the flag is set.
/// Any failure (e.g. `ENOSYS` when the kernel does not provide it) leaves the flag
/// `false`. `memfd_secret` is Linux-only, so on other Unix targets the probe is
/// compiled out and the flag stays `false`.
///
/// # Safety
///
/// Writes the `static mut` flag without synchronization of its own. Call it only
/// through `SUPPORTS_MEMFD_SECRET_INIT.call_once`, so the write cannot race with
/// readers.
#[cfg(all(feature = "use_os", unix))]
unsafe fn supports_memfd_secret_init() {
   #[cfg(target_os = "linux")]
   {
      use libc::{SYS_memfd_secret, close, syscall};

      // SAFETY: `memfd_secret` takes a single flags argument; 0 requests no flags.
      let res = unsafe { syscall(SYS_memfd_secret as _, 0isize) };

      if res >= 0 {
         // SAFETY: `res` is a descriptor we just created and nobody else owns.
         unsafe { close(res as libc::c_int) };
         // SAFETY: the caller guarantees exclusive access via `Once` (see `# Safety`).
         unsafe { SUPPORTS_MEMFD_SECRET = true };
      }
   }
}
81
82/// Allocate memory
83///
84/// For `Windows` it always uses [memsec::malloc_sized]
85///
86/// For `Unix` it uses [memsec::memfd_secret_sized] if `memfd_secret` is supported
87///
88/// If the allocation fails it fallbacks to [memsec::malloc_sized]
89pub(crate) unsafe fn alloc<T>(size: usize) -> Result<NonNull<T>, Error> {
90   #[cfg(feature = "use_os")]
91   {
92      #[cfg(windows)]
93      unsafe {
94         let allocated_ptr = memsec::malloc_sized(size);
95         let non_null = allocated_ptr.ok_or(Error::AllocationFailed)?;
96         let ptr = non_null.as_ptr() as *mut T;
97         NonNull::new(ptr).ok_or(Error::NullAllocation)
98      }
99
100      #[cfg(unix)]
101      {
102         SUPPORTS_MEMFD_SECRET_INIT.call_once(|| unsafe { supports_memfd_secret_init() });
103
104         let supports_memfd_secret = unsafe { SUPPORTS_MEMFD_SECRET };
105
106         let ptr_opt = if supports_memfd_secret {
107            unsafe { memsec::memfd_secret_sized(size) }
108         } else {
109            None
110         };
111
112         if let Some(ptr) = ptr_opt {
113            NonNull::new(ptr.as_ptr() as *mut T).ok_or(Error::NullAllocation)
114         } else {
115            unsafe {
116               let allocated_ptr = memsec::malloc_sized(size);
117               let non_null = allocated_ptr.ok_or(Error::AllocationFailed)?;
118               let ptr = non_null.as_ptr() as *mut T;
119               NonNull::new(ptr).ok_or(Error::NullAllocation)
120            }
121         }
122      }
123   }
124
125   #[cfg(not(feature = "use_os"))]
126   {
127      let layout =
128         Layout::from_size_align(size, mem::align_of::<T>()).map_err(|_| Error::AlignmentFailed)?;
129      let ptr = unsafe { alloc::alloc(layout) as *mut T };
130      if ptr.is_null() {
131         return Err(Error::NullAllocation);
132      }
133      ptr
134   }
135}
136
137#[cfg(feature = "use_os")]
138pub(crate) fn free<T>(ptr: NonNull<T>) {
139   #[cfg(windows)]
140   unsafe {
141      memsec::free(ptr);
142   }
143
144   #[cfg(unix)]
145   {
146      let supports_memfd_secret = unsafe { SUPPORTS_MEMFD_SECRET };
147      if supports_memfd_secret {
148         unsafe { memsec::free_memfd_secret(ptr) };
149      } else {
150         unsafe { memsec::free(ptr) };
151      }
152   }
153}
154
/// Change the page protection of the memory behind `ptr` via `memsec::mprotect`.
///
/// Returns whatever `memsec::mprotect` reports (`true` on success). Test builds also
/// print a message to stderr when the call fails.
///
/// NOTE(review): `memsec::mprotect` presumably expects a pointer from memsec's own
/// allocators (e.g. one returned by `alloc`). This safe wrapper cannot verify that;
/// confirm that call sites uphold it.
#[cfg(feature = "use_os")]
pub(crate) fn mprotect<T>(ptr: NonNull<T>, prot: Prot::Ty) -> bool {
   // SAFETY: relies on callers passing memsec-allocated pointers (see NOTE above).
   let protected = unsafe { memsec::mprotect(ptr, prot) };

   // Failures are only surfaced in test output; other builds stay silent.
   #[cfg(test)]
   {
      if !protected {
         eprintln!("mprotect failed");
      }
   }

   protected
}
166
167
#[cfg(test)]
mod tests {
   /// Probes `memfd_secret` support. When it is available, checks that a one-byte secret
   /// allocation succeeds and releases it again.
   ///
   /// Gated on `use_os` + Linux because the probe state only exists with `use_os` and
   /// `memfd_secret` is a Linux-only syscall.
   #[cfg(all(feature = "use_os", target_os = "linux"))]
   #[test]
   fn test_supports_memfd_secret() {
      use super::*;

      SUPPORTS_MEMFD_SECRET_INIT.call_once(|| unsafe { supports_memfd_secret_init() });

      let supports = unsafe { SUPPORTS_MEMFD_SECRET };

      if supports {
         println!("memfd_secret is supported");
         let ptr = unsafe { memsec::memfd_secret_sized(core::mem::size_of::<u8>()) }
            .expect("memfd_secret is supported, so a one-byte allocation must succeed");
         // Release the allocation instead of leaking the secret mapping.
         unsafe { memsec::free_memfd_secret(ptr) };
      } else {
         println!("memfd_secret is not supported");
      }
   }

   /// Allocates through the crate's `alloc`, writes and reads back every byte, then frees.
   #[cfg(feature = "use_os")]
   #[test]
   fn test_alloc_free_roundtrip() {
      const LEN: usize = 64;
      let ptr = unsafe { super::alloc::<u8>(LEN) }.expect("secure allocation should succeed");
      // SAFETY: `ptr` points to `LEN` writable bytes owned by this test until `free`.
      let bytes = unsafe {
         core::ptr::write_bytes(ptr.as_ptr(), 0xA5, LEN);
         core::slice::from_raw_parts(ptr.as_ptr(), LEN)
      };
      assert!(bytes.iter().all(|&b| b == 0xA5));
      super::free(ptr);
   }

   #[cfg(feature = "serde")]
   #[test]
   fn test_array_and_secure_vec_serde_compatibility() {
      use super::*;
      let exposed_array: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_array).unwrap();
      let vec: SecureVec<u8> = array.clone().into();

      let array_json_string = serde_json::to_string(&array).unwrap();
      let array_json_bytes = serde_json::to_vec(&array).unwrap();
      let vec_json_string = serde_json::to_string(&vec).unwrap();
      let vec_json_bytes = serde_json::to_vec(&vec).unwrap();

      assert_eq!(array_json_string, vec_json_string);
      assert_eq!(array_json_bytes, vec_json_bytes);

      let deserialized_array_from_string: SecureArray<u8, 3> =
         serde_json::from_str(&array_json_string).unwrap();

      let deserialized_array_from_bytes: SecureArray<u8, 3> =
         serde_json::from_slice(&array_json_bytes).unwrap();

      let deserialized_vec_from_string: SecureVec<u8> =
         serde_json::from_str(&vec_json_string).unwrap();

      let deserialized_vec_from_bytes: SecureVec<u8> =
         serde_json::from_slice(&vec_json_bytes).unwrap();

      deserialized_array_from_string.unlock(|slice| {
         deserialized_vec_from_string.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });

      deserialized_array_from_bytes.unlock(|slice| {
         deserialized_vec_from_bytes.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });
   }
}