secure_types/
array.rs

#[cfg(not(feature = "std"))]
use alloc::alloc::{Layout, alloc, dealloc};

use super::{Error, SecureVec};
use core::{marker::PhantomData, mem, ptr::NonNull};
use zeroize::Zeroize;

#[cfg(feature = "std")]
use super::page_aligned_size;
#[cfg(feature = "std")]
use memsec::Prot;
12
13/// A fixed-size array allocated in a secure memory region.
14///
15/// ## Security Model
16///
17/// When compiled with the `std` feature (the default), it provides several layers of protection:
18/// - **Zeroization on Drop**: The memory is zeroized when the array is dropped.
19/// - **Memory Locking**: The underlying memory pages are locked using `mlock` & `madvise` for (Unix) or
20///   `VirtualLock` & `VirtualProtect` for (Windows) to prevent the OS from memory-dump/swap to disk or other processes accessing the memory.
21/// - **Memory Encryption**: On Windows, the memory is also encrypted using `CryptProtectMemory`.
22///
23/// In a `no_std` environment, it falls back to providing only the **zeroization-on-drop** guarantee.
24///
25/// # Program Termination
26///
27/// Direct indexing (e.g., `array[0]`) on a locked array will cause the operating system
28/// to terminate the process with an access violation error. Always use the provided
29/// scope methods (`unlock`, `unlock_mut`) for safe access.
30/// 
31/// # Notes
32/// 
33/// If you return a new allocated `[T; LENGTH]` from one of the unlock methods you are responsible for zeroizing the memory.
34///
35/// # Example
36///
37/// ```
38/// use secure_types::{SecureArray, Zeroize};
39///
40/// let exposed_key: &mut [u8; 32] = &mut [1u8; 32];
41/// let secure_key: SecureArray<u8, 32> = SecureArray::from_slice_mut(exposed_key).unwrap();
42///
43/// secure_key.unlock(|unlocked_slice| {
44///     assert_eq!(unlocked_slice.len(), 32);
45///     assert_eq!(unlocked_slice[0], 1);
46/// });
47/// 
48/// // Not recommended but if you allocate a new [u8; LENGTH] make sure to zeroize it
49/// let mut exposed = secure_key.unlock(|unlocked_slice| {
50///     [unlocked_slice[0], unlocked_slice[1], unlocked_slice[2]]
51/// });
52/// 
53/// // Do what you need to to do with the new array
54/// // When you are done with it, zeroize it
55/// exposed.zeroize();
56/// ```
57pub struct SecureArray<T, const LENGTH: usize>
58where
59   T: Zeroize,
60{
61   ptr: NonNull<T>,
62   _marker: PhantomData<T>,
63}
64
65unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
66unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
67
68impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
69where
70   T: Zeroize,
71{
72   /// Creates an empty (but allocated) SecureArray.
73   ///
74   /// The memory is allocated but not initialized, and it's the caller's responsibility to fill it.
75   pub fn empty() -> Result<Self, Error> {
76      let size = LENGTH * mem::size_of::<T>();
77      if size == 0 {
78         // Cannot create a zero-sized secure array
79         return Err(Error::LengthCannotBeZero);
80      }
81
82      #[cfg(feature = "std")]
83      let new_ptr = {
84         let aligned_size = page_aligned_size(size);
85         let allocated_ptr = unsafe { memsec::malloc_sized(aligned_size) };
86         allocated_ptr.ok_or(Error::AllocationFailed)?.as_ptr() as *mut T
87      };
88
89      #[cfg(not(feature = "std"))]
90      let new_ptr = {
91         let layout = Layout::from_size_align(size, mem::align_of::<T>())
92            .map_err(|_| Error::AllocationFailed)?;
93         let ptr = unsafe { alloc::alloc(layout) as *mut T };
94         if ptr.is_null() {
95            return Err(Error::AllocationFailed);
96         }
97         ptr
98      };
99
100      let non_null = NonNull::new(new_ptr).ok_or(Error::NullAllocation)?;
101
102      let secure_array = SecureArray {
103         ptr: non_null,
104         _marker: PhantomData,
105      };
106
107      let (encrypted, locked) = secure_array.lock_memory();
108
109      #[cfg(feature = "std")]
110      if !locked {
111         return Err(Error::LockFailed);
112      }
113
114      #[cfg(feature = "std")]
115      if !encrypted {
116         return Err(Error::CryptProtectMemoryFailed);
117      }
118
119      Ok(secure_array)
120   }
121
122   /// Creates a new SecureArray from a `&mut [T; LENGTH]`.
123   ///
124   /// The passed slice is zeroized afterwards
125   pub fn from_slice_mut(content: &mut [T; LENGTH]) -> Result<Self, Error> {
126      let secure_array = Self::empty()?;
127
128      secure_array.unlock_memory();
129
130      unsafe {
131         // Copy the data from the source array into the secure memory region
132         core::ptr::copy_nonoverlapping(
133            content.as_ptr(),
134            secure_array.ptr.as_ptr(),
135            LENGTH,
136         );
137      }
138
139      content.zeroize();
140
141      let (encrypted, locked) = secure_array.lock_memory();
142
143      #[cfg(feature = "std")]
144      if !locked {
145         return Err(Error::LockFailed);
146      }
147
148      #[cfg(feature = "std")]
149      if !encrypted {
150         return Err(Error::CryptProtectMemoryFailed);
151      }
152
153      Ok(secure_array)
154   }
155
156   /// Creates a new SecureArray from a `&[T; LENGTH]`.
157   ///
158   /// The array is not zeroized, you are responsible for zeroizing it
159   pub fn from_slice(content: &[T; LENGTH]) -> Result<Self, Error> {
160      let secure_array = Self::empty()?;
161
162      secure_array.unlock_memory();
163
164      unsafe {
165         // Copy the data from the source array into the secure memory region
166         core::ptr::copy_nonoverlapping(
167            content.as_ptr(),
168            secure_array.ptr.as_ptr(),
169            LENGTH,
170         );
171      }
172
173      let (encrypted, locked) = secure_array.lock_memory();
174
175      #[cfg(feature = "std")]
176      if !locked {
177         return Err(Error::LockFailed);
178      }
179
180      #[cfg(feature = "std")]
181      if !encrypted {
182         return Err(Error::CryptProtectMemoryFailed);
183      }
184
185      Ok(secure_array)
186   }
187
188   pub fn len(&self) -> usize {
189      LENGTH
190   }
191
192   pub fn is_empty(&self) -> bool {
193      self.len() == 0
194   }
195
196   pub fn as_ptr(&self) -> *const T {
197      self.ptr.as_ptr()
198   }
199
200   pub fn as_mut_ptr(&mut self) -> *mut u8 {
201      self.ptr.as_ptr() as *mut u8
202   }
203
204   #[allow(dead_code)]
205   fn aligned_size(&self) -> usize {
206      let size = self.len() * mem::size_of::<T>();
207      #[cfg(feature = "std")]
208      {
209         page_aligned_size(size)
210      }
211      #[cfg(not(feature = "std"))]
212      {
213         size // No page alignment in no_std
214      }
215   }
216
217   #[cfg(all(feature = "std", windows))]
218   fn encypt_memory(&self) -> bool {
219      let ptr = self.as_ptr() as *mut u8;
220      super::crypt_protect_memory(ptr, self.aligned_size())
221   }
222
223   #[cfg(all(feature = "std", windows))]
224   fn decrypt_memory(&self) -> bool {
225      let ptr = self.as_ptr() as *mut u8;
226      super::crypt_unprotect_memory(ptr, self.aligned_size())
227   }
228
229   pub(crate) fn lock_memory(&self) -> (bool, bool) {
230      #[cfg(feature = "std")]
231      {
232         #[cfg(windows)]
233         {
234            let encrypt_ok = self.encypt_memory();
235            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
236            (encrypt_ok, mprotect_ok)
237         }
238         #[cfg(unix)]
239         {
240            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
241            (true, mprotect_ok)
242         }
243      }
244      #[cfg(not(feature = "std"))]
245      {
246         (true, true) // No-op: always "succeeds"
247      }
248   }
249
250   pub(crate) fn unlock_memory(&self) -> (bool, bool) {
251      #[cfg(feature = "std")]
252      {
253         #[cfg(windows)]
254         {
255            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
256            if !mprotect_ok {
257               return (false, false);
258            }
259            let decrypt_ok = self.decrypt_memory();
260            (decrypt_ok, mprotect_ok)
261         }
262         #[cfg(unix)]
263         {
264            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
265            (true, mprotect_ok)
266         }
267      }
268
269      #[cfg(not(feature = "std"))]
270      {
271         (true, true) // No-op: always "succeeds"
272      }
273   }
274
275   /// Immutable access to the array's data as a `&[T]`
276   pub fn unlock<F, R>(&self, f: F) -> R
277   where
278      F: FnOnce(&[T]) -> R,
279   {
280      self.unlock_memory();
281      let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
282      let result = f(slice);
283      self.lock_memory();
284      result
285   }
286
287   /// Mutable access to the array's data as a `&mut [T]`
288   pub fn unlock_mut<F, R>(&mut self, f: F) -> R
289   where
290      F: FnOnce(&mut [T]) -> R,
291   {
292      self.unlock_memory();
293      let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
294      let result = f(slice);
295      self.lock_memory();
296      result
297   }
298
299   /// Securely erases the contents of the array by zeroizing the memory.
300   pub fn erase(&mut self) {
301      self.unlock_mut(|slice| {
302         for element in slice.iter_mut() {
303            element.zeroize();
304         }
305      });
306   }
307}
308
309impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
310   type Output = T;
311   fn index(&self, index: usize) -> &Self::Output {
312      assert!(index < self.len(), "Index out of bounds");
313      unsafe {
314         let ptr = self.ptr.as_ptr().add(index);
315         &*ptr
316      }
317   }
318}
319
320impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
321   fn index_mut(&mut self, index: usize) -> &mut Self::Output {
322      assert!(index < self.len(), "Index out of bounds");
323      unsafe {
324         let ptr = self.ptr.as_ptr().add(index);
325         &mut *ptr
326      }
327   }
328}
329
330impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
331   fn drop(&mut self) {
332      self.erase();
333      self.unlock_memory();
334
335      let size = LENGTH * mem::size_of::<T>();
336      if size == 0 {
337         return;
338      }
339
340      unsafe {
341         #[cfg(feature = "std")]
342         {
343            memsec::free(self.ptr);
344         }
345         #[cfg(not(feature = "std"))]
346         {
347            // Recreate the layout to deallocate correctly
348            let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
349            dealloc(self.ptr.as_ptr() as *mut u8, layout);
350         }
351      }
352   }
353}
354
355impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
356   fn clone(&self) -> Self {
357      let mut new_array = Self::empty().unwrap();
358      self.unlock(|src_slice| {
359         new_array.unlock_mut(|dest_slice| {
360            dest_slice.clone_from_slice(src_slice);
361         });
362      });
363      new_array
364   }
365}
366
367impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
368   type Error = Error;
369
370   /// Tries to convert a `SecureVec<u8>` into a `SecureArray<u8, LENGTH>`.
371   ///
372   /// This operation will only succeed if `vec.len() == LENGTH`.
373   /// 
374   /// The `SecureVec` is consumed.
375   fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
376      if vec.len() != LENGTH {
377         return Err(Error::LengthMismatch);
378      }
379
380      let mut new_array = Self::empty()?;
381
382      vec.unlock_slice(|vec_slice| {
383         new_array.unlock_mut(|array_slice| {
384            array_slice.copy_from_slice(vec_slice);
385         });
386      });
387
388      Ok(new_array)
389   }
390}
391
392#[cfg(feature = "serde")]
393impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
394   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
395   where
396      S: serde::Serializer,
397   {
398      self.unlock(|slice| serializer.collect_seq(slice.iter()))
399   }
400}
401
402#[cfg(feature = "serde")]
403impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
404   fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
405   where
406      D: serde::Deserializer<'de>,
407   {
408      struct SecureArrayVisitor<const L: usize>;
409
410      impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
411         type Value = SecureArray<u8, L>;
412
413         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
414            write!(formatter, "a byte array of length {}", L)
415         }
416
417         fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
418         where
419            A: serde::de::SeqAccess<'de>,
420         {
421            let mut data: SecureVec<u8> =
422               SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;
423            while let Some(byte) = seq.next_element()? {
424               data.push(byte);
425            }
426
427            // Check that the deserialized data has the exact length required.
428            if data.len() != L {
429               return Err(serde::de::Error::invalid_length(
430                  data.len(),
431                  &self,
432               ));
433            }
434
435            SecureArray::try_from(data).map_err(serde::de::Error::custom)
436         }
437      }
438
439      deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
440   }
441}
442
443#[cfg(all(test, feature = "std"))]
444mod tests {
445   use super::*;
446   use std::process::{Command, Stdio};
447   use std::sync::{Arc, Mutex};
448
449   #[test]
450   fn test_creation() {
451      let exposed_mut = &mut [1, 2, 3];
452      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_mut).unwrap();
453      assert_eq!(array.len(), 3);
454
455      array.unlock(|slice| {
456         assert_eq!(slice, &[1, 2, 3]);
457      });
458
459      assert_eq!(exposed_mut, &[0u8; 3]);
460
461      let exposed = &[1, 2, 3];
462      let array: SecureArray<u8, 3> = SecureArray::from_slice(exposed).unwrap();
463      assert_eq!(array.len(), 3);
464
465      array.unlock(|slice| {
466         assert_eq!(slice, &[1, 2, 3]);
467      });
468
469      assert_eq!(exposed, &[1, 2, 3]);
470   }
471
472   #[test]
473   fn test_from_secure_vec() {
474      let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
475      let array: SecureArray<u8, 3> = vec.try_into().unwrap();
476      assert_eq!(array.len(), 3);
477      array.unlock(|slice| {
478         assert_eq!(slice, &[1, 2, 3]);
479      });
480   }
481
482   #[test]
483   fn test_erase() {
484      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
485      let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
486      array.erase();
487      array.unlock(|slice| {
488         assert_eq!(slice, &[0u8; 3]);
489      });
490   }
491
492   #[test]
493   fn test_size_cannot_be_zero() {
494      let secure: SecureArray<u8, 3> = SecureArray::from_slice(&[1, 2, 3]).unwrap();
495      let size = secure.aligned_size();
496      assert_eq!(size > 0, true);
497
498      let secure: SecureArray<u8, 3> = SecureArray::empty().unwrap();
499      let size = secure.aligned_size();
500      assert_eq!(size > 0, true);
501   }
502
503   #[test]
504   #[should_panic]
505   fn test_length_cannot_be_zero() {
506      let secure_vec = SecureVec::new().unwrap();
507      let _secure_array: SecureArray<u8, 0> = SecureArray::try_from(secure_vec).unwrap();
508   }
509
510   #[test]
511   fn lock_unlock() {
512      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
513      let secure: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
514      let size = secure.aligned_size();
515      assert_eq!(size > 0, true);
516
517      let (decrypted, unlocked) = secure.unlock_memory();
518      assert!(decrypted);
519      assert!(unlocked);
520
521      let (encrypted, locked) = secure.lock_memory();
522      assert!(encrypted);
523      assert!(locked);
524   }
525
526   #[test]
527   fn test_clone() {
528      let mut array1: SecureArray<u8, 3> = SecureArray::empty().unwrap();
529      array1.unlock_mut(|slice| {
530         slice[0] = 1;
531         slice[1] = 2;
532         slice[2] = 3;
533      });
534
535      let array2 = array1.clone();
536
537      array2.unlock(|slice| {
538         assert_eq!(slice, &[1, 2, 3]);
539      });
540
541      array1.unlock(|slice| {
542         assert_eq!(slice, &[1, 2, 3]);
543      });
544   }
545
546   #[test]
547   fn test_thread_safety() {
548      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
549      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
550      let arc_array = Arc::new(Mutex::new(array));
551      let mut handles = Vec::new();
552
553      for _ in 0..5u8 {
554         let array_clone = Arc::clone(&arc_array);
555         let handle = std::thread::spawn(move || {
556            let mut guard = array_clone.lock().unwrap();
557            guard.unlock_mut(|slice| {
558               slice[0] += 1;
559            });
560         });
561         handles.push(handle);
562      }
563
564      for handle in handles {
565         handle.join().unwrap();
566      }
567
568      let final_array = arc_array.lock().unwrap();
569      final_array.unlock(|slice| {
570         assert_eq!(slice[0], 6);
571         assert_eq!(slice[1], 2);
572         assert_eq!(slice[2], 3);
573      });
574   }
575
576   #[test]
577   fn test_index_should_fail_when_locked() {
578      let arg = "CRASH_TEST_ARRAY_LOCKED";
579
580      if std::env::args().any(|a| a == arg) {
581         let exposed: &mut [u8; 3] = &mut [1, 2, 3];
582         let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
583         let _value = core::hint::black_box(array[0]);
584
585         std::process::exit(1);
586      }
587
588      let child = Command::new(std::env::current_exe().unwrap())
589         .arg("array::tests::test_index_should_fail_when_locked")
590         .arg(arg)
591         .arg("--nocapture")
592         .stdout(Stdio::piped())
593         .stderr(Stdio::piped())
594         .spawn()
595         .expect("Failed to spawn child process");
596
597      let output = child.wait_with_output().expect("Failed to wait on child");
598      let status = output.status;
599
600      assert!(
601         !status.success(),
602         "Process exited successfully with code {:?}, but it should have crashed.",
603         status.code()
604      );
605
606      #[cfg(unix)]
607      {
608         use std::os::unix::process::ExitStatusExt;
609         let signal = status
610            .signal()
611            .expect("Process was not terminated by a signal on Unix.");
612         assert!(
613            signal == libc::SIGSEGV || signal == libc::SIGBUS,
614            "Process terminated with unexpected signal: {}",
615            signal
616         );
617         println!(
618            "Test passed: Process correctly terminated with signal {}.",
619            signal
620         );
621      }
622
623      #[cfg(windows)]
624      {
625         const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
626         assert_eq!(
627            status.code(),
628            Some(STATUS_ACCESS_VIOLATION),
629            "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
630            status.code()
631         );
632         eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
633      }
634   }
635
636   #[test]
637   fn test_unlock_mut() {
638      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
639      let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
640
641      array.unlock_mut(|slice| {
642         slice[1] = 100;
643      });
644
645      array.unlock(|slice| {
646         assert_eq!(slice, &[1, 100, 3]);
647      });
648   }
649
650   #[cfg(feature = "serde")]
651   #[test]
652   fn test_serde() {
653      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
654      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
655      let json_string = serde_json::to_string(&array).expect("Serialization failed");
656      let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");
657
658      let deserialized_string: SecureArray<u8, 3> =
659         serde_json::from_str(&json_string).expect("Deserialization failed");
660
661      let deserialized_bytes: SecureArray<u8, 3> =
662         serde_json::from_slice(&json_bytes).expect("Deserialization failed");
663
664      deserialized_string.unlock(|slice| {
665         assert_eq!(slice, &[1, 2, 3]);
666      });
667
668      deserialized_bytes.unlock(|slice| {
669         assert_eq!(slice, &[1, 2, 3]);
670      });
671   }
672}