// secure_types/array.rs

#[cfg(not(feature = "std"))]
use alloc::alloc::{Layout, alloc, dealloc};

use super::Error;
use core::{marker::PhantomData, mem, ptr::NonNull};
use zeroize::Zeroize;

#[cfg(feature = "std")]
use super::page_size;
#[cfg(feature = "std")]
use memsec::Prot;

13/// A fixed-size array allocated in a secure memory region.
14///
15/// `SecureArray` provides the same core security guarantees as the other types in this
16/// crate, including zeroization on drop and optional memory locking/encryption when
17// compiled with the `std` feature.
18///
19/// It is ideal for secrets of a known, fixed length.
20///
21/// # Program Termination
22///
23/// Direct indexing (e.g., `array[0]`) on a locked array will cause the operating system
24/// to terminate the process with an access violation error. Always use the provided
25/// scope methods (`unlocked_scope`, `unlocked_mut_scope`) for safe access.
26///
27/// # Examples
28///
29/// ```
30/// use secure_types::SecureArray;
31///
32/// let key_data = [1u8; 32];
33/// let secure_key: SecureArray<u8, 32> = SecureArray::new(key_data).unwrap();
34///
35/// secure_key.unlocked_scope(|unlocked_slice| {
36///     assert_eq!(unlocked_slice.len(), 32);
37///     assert_eq!(unlocked_slice[0], 1);
38/// });
39/// ```
40pub struct SecureArray<T, const LENGTH: usize>
41where
42   T: Zeroize,
43{
44   ptr: NonNull<T>,
45   _marker: PhantomData<T>,
46}
47
48unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
49unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
50
51impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
52where
53   T: Zeroize,
54{
55   /// Creates an empty (but allocated) SecureArray.
56   /// The memory is allocated but not initialized, and it's the caller's responsibility to fill it.
57   pub fn empty() -> Result<Self, Error> {
58      let size = LENGTH * mem::size_of::<T>();
59      if size == 0 {
60         // Cannot create a zero-sized secure array
61         return Err(Error::AllocationFailed);
62      }
63
64      #[cfg(feature = "std")]
65      let new_ptr = {
66         let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
67         let allocated_ptr = unsafe { memsec::malloc_sized(aligned_size) };
68         allocated_ptr.ok_or(Error::AllocationFailed)?.as_ptr() as *mut T
69      };
70
71      #[cfg(not(feature = "std"))]
72      let new_ptr = {
73         let layout = Layout::from_size_align(size, mem::align_of::<T>())
74            .map_err(|_| Error::AllocationFailed)?;
75         let ptr = unsafe { alloc::alloc(layout) as *mut T };
76         if ptr.is_null() {
77            return Err(Error::AllocationFailed);
78         }
79         ptr
80      };
81
82      let non_null = NonNull::new(new_ptr).ok_or(Error::NullAllocation)?;
83
84      let secure_array = SecureArray {
85         ptr: non_null,
86         _marker: PhantomData,
87      };
88
89      let (encrypted, locked) = secure_array.lock_memory();
90
91      #[cfg(feature = "std")]
92      if !locked {
93         return Err(Error::LockFailed);
94      }
95
96      #[cfg(feature = "std")]
97      if !encrypted {
98         return Err(Error::CryptProtectMemoryFailed);
99      }
100
101      Ok(secure_array)
102   }
103
104   /// Creates a new SecureArray from a given array.
105   pub fn new(mut content: [T; LENGTH]) -> Result<Self, Error> {
106      let secure_array = Self::empty()?;
107
108      secure_array.unlock_memory();
109
110      unsafe {
111         // Copy the data from the source array into the secure memory region
112         core::ptr::copy_nonoverlapping(
113            content.as_ptr(),
114            secure_array.ptr.as_ptr(),
115            LENGTH,
116         );
117      }
118
119      content.zeroize();
120
121      let (encrypted, locked) = secure_array.lock_memory();
122
123      #[cfg(feature = "std")]
124      if !locked {
125         return Err(Error::LockFailed);
126      }
127
128      #[cfg(feature = "std")]
129      if !encrypted {
130         return Err(Error::CryptProtectMemoryFailed);
131      }
132
133      Ok(secure_array)
134   }
135
136   pub fn len(&self) -> usize {
137      LENGTH
138   }
139
140   pub fn as_ptr(&self) -> *const T {
141      self.ptr.as_ptr()
142   }
143
144   pub fn as_mut_ptr(&mut self) -> *mut u8 {
145      self.ptr.as_ptr() as *mut u8
146   }
147
148   #[allow(dead_code)]
149   fn allocated_byte_size(&self) -> usize {
150      let size = self.len() * mem::size_of::<T>();
151      #[cfg(feature = "std")]
152      {
153         (size + page_size() - 1) & !(page_size() - 1)
154      }
155      #[cfg(not(feature = "std"))]
156      {
157         size // No page alignment in no_std
158      }
159   }
160
161   #[cfg(all(feature = "std", windows))]
162   fn encypt_memory(&self) -> bool {
163      let ptr = self.as_ptr() as *mut u8;
164      super::crypt_protect_memory(ptr, self.allocated_byte_size())
165   }
166
167   #[cfg(all(feature = "std", windows))]
168   fn decrypt_memory(&self) -> bool {
169      let ptr = self.as_ptr() as *mut u8;
170      super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
171   }
172
173   pub(crate) fn lock_memory(&self) -> (bool, bool) {
174      #[cfg(feature = "std")]
175      {
176         #[cfg(windows)]
177         {
178            let encrypt_ok = self.encypt_memory();
179            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
180            (encrypt_ok, mprotect_ok)
181         }
182         #[cfg(unix)]
183         {
184            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
185            (true, mprotect_ok)
186         }
187      }
188      #[cfg(not(feature = "std"))]
189      {
190         (true, true) // No-op: always "succeeds"
191      }
192   }
193
194   pub(crate) fn unlock_memory(&self) -> (bool, bool) {
195      #[cfg(feature = "std")]
196      {
197         #[cfg(windows)]
198         {
199            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
200            if !mprotect_ok {
201               return (false, false);
202            }
203            let decrypt_ok = self.decrypt_memory();
204            (decrypt_ok, mprotect_ok)
205         }
206         #[cfg(unix)]
207         {
208            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
209            (true, mprotect_ok)
210         }
211      }
212
213      #[cfg(not(feature = "std"))]
214      {
215         (true, true) // No-op: always "succeeds"
216      }
217   }
218
219   /// Provides scoped, immutable access to the array's data as a slice.
220   pub fn unlocked_scope<F, R>(&self, f: F) -> R
221   where
222      F: FnOnce(&[T]) -> R,
223   {
224      self.unlock_memory();
225      let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
226      let result = f(slice);
227      self.lock_memory();
228      result
229   }
230
231   /// Provides scoped, mutable access to the array's data as a mutable slice.
232   pub fn unlocked_mut_scope<F, R>(&mut self, f: F) -> R
233   where
234      F: FnOnce(&mut [T]) -> R,
235   {
236      self.unlock_memory();
237      let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
238      let result = f(slice);
239      self.lock_memory();
240      result
241   }
242
243   /// Securely erases the contents of the array by zeroizing the memory.
244   pub fn erase(&mut self) {
245      self.unlocked_mut_scope(|slice| {
246         for element in slice.iter_mut() {
247            element.zeroize();
248         }
249      });
250   }
251}
252
253impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
254   type Output = T;
255   fn index(&self, index: usize) -> &Self::Output {
256      assert!(index < self.len(), "Index out of bounds");
257      unsafe {
258         let ptr = self.ptr.as_ptr().add(index);
259         let reference = &*ptr;
260         reference
261      }
262   }
263}
264
265impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
266   fn index_mut(&mut self, index: usize) -> &mut Self::Output {
267      assert!(index < self.len(), "Index out of bounds");
268      unsafe {
269         let ptr = self.ptr.as_ptr().add(index);
270         let reference = &mut *ptr;
271         reference
272      }
273   }
274}
275
276impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
277   fn drop(&mut self) {
278      self.erase();
279      self.unlock_memory();
280
281      let size = LENGTH * mem::size_of::<T>();
282      if size == 0 {
283         return;
284      }
285
286      unsafe {
287         #[cfg(feature = "std")]
288         {
289            memsec::free(self.ptr);
290         }
291         #[cfg(not(feature = "std"))]
292         {
293            // Recreate the layout to deallocate correctly
294            let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
295            dealloc(self.ptr.as_ptr() as *mut u8, layout);
296         }
297      }
298   }
299}
300
301impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
302   fn clone(&self) -> Self {
303      let mut new_array = Self::empty().unwrap();
304      self.unlocked_scope(|src_slice| {
305         new_array.unlocked_mut_scope(|dest_slice| {
306            dest_slice.clone_from_slice(src_slice);
307         });
308      });
309      new_array
310   }
311}
312
313impl<T, const LENGTH: usize> TryFrom<[T; LENGTH]> for SecureArray<T, LENGTH>
314where
315   T: Zeroize,
316{
317   type Error = Error;
318   fn try_from(s: [T; LENGTH]) -> Result<Self, Error> {
319      Self::new(s)
320   }
321}
322
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    /// A freshly built array reports its length and exposes its data inside a scope.
    #[test]
    fn test_creation() {
        let secret: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
        assert_eq!(secret.len(), 3);
        secret.unlocked_scope(|bytes| assert_eq!(bytes, &[1, 2, 3]));
    }

    /// `erase` leaves every element zeroed.
    #[test]
    fn test_erase() {
        let mut secret: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
        secret.erase();
        secret.unlocked_scope(|bytes| assert_eq!(bytes, &[0u8; 3]));
    }

    /// A clone carries the same data and leaves the original intact.
    #[test]
    fn test_clone() {
        let mut original: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        original.unlocked_mut_scope(|bytes| bytes.copy_from_slice(&[1, 2, 3]));

        let duplicate = original.clone();

        for array in [&duplicate, &original] {
            array.unlocked_scope(|bytes| assert_eq!(bytes, &[1, 2, 3]));
        }
    }

    /// Five threads each increment the first element under a mutex.
    #[test]
    fn test_thread_safety() {
        let shared = Arc::new(Mutex::new(SecureArray::new([1, 2, 3]).unwrap()));

        let workers: Vec<_> = (0..5u8)
            .map(|_| {
                let shared = Arc::clone(&shared);
                std::thread::spawn(move || {
                    let mut guard = shared.lock().unwrap();
                    guard.unlocked_mut_scope(|values| values[0] += 1);
                })
            })
            .collect();

        workers.into_iter().for_each(|worker| worker.join().unwrap());

        shared
            .lock()
            .unwrap()
            .unlocked_scope(|values| assert_eq!(values[0], 6));
    }

    /// Indexing a locked array must kill the process.
    ///
    /// The test re-runs its own binary with a marker argument. The child indexes a
    /// locked array, and the parent asserts the child died from an access violation.
    #[test]
    fn test_index_should_fail_when_locked() {
        let crash_flag = "CRASH_TEST_ARRAY_LOCKED";

        // Child mode: touch locked memory; reaching `exit(1)` means no fault occurred.
        if std::env::args().any(|a| a == crash_flag) {
            let locked: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
            let _value = core::hint::black_box(locked[0]);

            std::process::exit(1);
        }

        // Parent mode: spawn the child and inspect how it terminated.
        let child = Command::new(std::env::current_exe().unwrap())
            .args([
                "array::tests::test_index_should_fail_when_locked",
                crash_flag,
                "--nocapture",
            ])
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let status = child
            .wait_with_output()
            .expect("Failed to wait on child")
            .status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                matches!(signal, libc::SIGSEGV | libc::SIGBUS),
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }

    /// Writes made inside a mutable scope are visible in a later read scope.
    #[test]
    fn test_mutable_access_in_scope() {
        let mut secret: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();

        secret.unlocked_mut_scope(|bytes| bytes[1] = 100);

        secret.unlocked_scope(|bytes| assert_eq!(bytes, &[1, 100, 3]));
    }
}