secure_types/
array.rs

1#[cfg(not(feature = "use_os"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::{Error, SecureVec};
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "use_os")]
9use super::{alloc_aligned, page_aligned_size};
10#[cfg(feature = "use_os")]
11use memsec::Prot;
12
13/// A fixed-size array allocated in a secure memory region.
14///
15/// ## Security Model
16///
17/// When compiled with the `use_os` feature (the default), it provides several layers of protection:
18/// - **Zeroization on Drop**: The memory is zeroized when the array is dropped.
19/// - **Memory Locking**: The underlying memory pages are locked using `mlock` & `madvise` for (Unix) or
20///   `VirtualLock` & `VirtualProtect` for (Windows) to prevent the OS from memory-dump/swap to disk or other processes accessing the memory.
21/// - **Memory Encryption**: On Windows, the memory is also encrypted using `CryptProtectMemory`.
22///
23/// In a `no_std` environment, it falls back to providing only the **zeroization-on-drop** guarantee.
24///
25/// # Program Termination
26///
27/// Direct indexing (e.g., `array[0]`) on a locked array will cause the operating system
28/// to terminate the process with an access violation error. Always use the provided
29/// scope methods (`unlock`, `unlock_mut`) for safe access.
30///
31/// # Notes
32///
33/// If you return a new allocated `[T; LENGTH]` from one of the unlock methods you are responsible for zeroizing the memory.
34///
35/// # Example
36///
37/// ```
38/// use secure_types::{SecureArray, Zeroize};
39///
40/// let exposed_key: &mut [u8; 32] = &mut [1u8; 32];
41/// let secure_key: SecureArray<u8, 32> = SecureArray::from_slice_mut(exposed_key).unwrap();
42///
43/// secure_key.unlock(|unlocked_slice| {
44///     assert_eq!(unlocked_slice.len(), 32);
45///     assert_eq!(unlocked_slice[0], 1);
46/// });
47///
48/// // Not recommended but if you allocate a new [u8; LENGTH] make sure to zeroize it
49/// let mut exposed = secure_key.unlock(|unlocked_slice| {
50///     [unlocked_slice[0], unlocked_slice[1], unlocked_slice[2]]
51/// });
52///
53/// // Do what you need to to do with the new array
54/// // When you are done with it, zeroize it
55/// exposed.zeroize();
56/// ```
57pub struct SecureArray<T, const LENGTH: usize>
58where
59   T: Zeroize,
60{
61   ptr: NonNull<T>,
62   _marker: PhantomData<T>,
63}
64
65unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
66unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
67
68impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
69where
70   T: Zeroize,
71{
72   /// Creates an empty (but allocated) SecureArray.
73   ///
74   /// The memory is allocated but not initialized, and it's the caller's responsibility to fill it.
75   pub fn empty() -> Result<Self, Error> {
76      let size = LENGTH * mem::size_of::<T>();
77      if size == 0 {
78         // Cannot create a zero-sized secure array
79         return Err(Error::LengthCannotBeZero);
80      }
81
82      let ptr = alloc_aligned::<T>(size)?;
83
84      let secure_array = SecureArray {
85         ptr,
86         _marker: PhantomData,
87      };
88
89      let (encrypted, locked) = secure_array.lock_memory();
90
91      #[cfg(feature = "use_os")]
92      if !locked {
93         return Err(Error::LockFailed);
94      }
95
96      #[cfg(feature = "use_os")]
97      if !encrypted {
98         return Err(Error::CryptProtectMemoryFailed);
99      }
100
101      Ok(secure_array)
102   }
103
104   /// Creates a new SecureArray from a `&mut [T; LENGTH]`.
105   ///
106   /// The passed slice is zeroized afterwards
107   pub fn from_slice_mut(content: &mut [T; LENGTH]) -> Result<Self, Error> {
108      let secure_array = Self::empty()?;
109
110      secure_array.unlock_memory();
111
112      unsafe {
113         // Copy the data from the source array into the secure memory region
114         core::ptr::copy_nonoverlapping(
115            content.as_ptr(),
116            secure_array.ptr.as_ptr(),
117            LENGTH,
118         );
119      }
120
121      content.zeroize();
122
123      let (encrypted, locked) = secure_array.lock_memory();
124
125      #[cfg(feature = "use_os")]
126      if !locked {
127         return Err(Error::LockFailed);
128      }
129
130      #[cfg(feature = "use_os")]
131      if !encrypted {
132         return Err(Error::CryptProtectMemoryFailed);
133      }
134
135      Ok(secure_array)
136   }
137
138   /// Creates a new SecureArray from a `&[T; LENGTH]`.
139   ///
140   /// The array is not zeroized, you are responsible for zeroizing it
141   pub fn from_slice(content: &[T; LENGTH]) -> Result<Self, Error> {
142      let secure_array = Self::empty()?;
143
144      secure_array.unlock_memory();
145
146      unsafe {
147         // Copy the data from the source array into the secure memory region
148         core::ptr::copy_nonoverlapping(
149            content.as_ptr(),
150            secure_array.ptr.as_ptr(),
151            LENGTH,
152         );
153      }
154
155      let (encrypted, locked) = secure_array.lock_memory();
156
157      #[cfg(feature = "use_os")]
158      if !locked {
159         return Err(Error::LockFailed);
160      }
161
162      #[cfg(feature = "use_os")]
163      if !encrypted {
164         return Err(Error::CryptProtectMemoryFailed);
165      }
166
167      Ok(secure_array)
168   }
169
170   pub fn len(&self) -> usize {
171      LENGTH
172   }
173
174   pub fn is_empty(&self) -> bool {
175      self.len() == 0
176   }
177
178   pub fn as_ptr(&self) -> *const T {
179      self.ptr.as_ptr()
180   }
181
182   pub fn as_mut_ptr(&mut self) -> *mut u8 {
183      self.ptr.as_ptr() as *mut u8
184   }
185
186   #[allow(dead_code)]
187   fn aligned_size(&self) -> usize {
188      let size = self.len() * mem::size_of::<T>();
189      #[cfg(feature = "use_os")]
190      {
191         unsafe { page_aligned_size(size) }
192      }
193      #[cfg(not(feature = "use_os"))]
194      {
195         size // No page alignment in no_std
196      }
197   }
198
199   #[cfg(all(feature = "use_os", windows))]
200   fn encypt_memory(&self) -> bool {
201      let ptr = self.as_ptr() as *mut u8;
202      super::crypt_protect_memory(ptr, self.aligned_size())
203   }
204
205   #[cfg(all(feature = "use_os", windows))]
206   fn decrypt_memory(&self) -> bool {
207      let ptr = self.as_ptr() as *mut u8;
208      super::crypt_unprotect_memory(ptr, self.aligned_size())
209   }
210
211   pub(crate) fn lock_memory(&self) -> (bool, bool) {
212      #[cfg(feature = "use_os")]
213      {
214         #[cfg(windows)]
215         {
216            let encrypt_ok = self.encypt_memory();
217            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
218            (encrypt_ok, mprotect_ok)
219         }
220         #[cfg(unix)]
221         {
222            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
223            (true, mprotect_ok)
224         }
225      }
226      #[cfg(not(feature = "use_os"))]
227      {
228         (true, true) // No-op: always "succeeds"
229      }
230   }
231
232   pub(crate) fn unlock_memory(&self) -> (bool, bool) {
233      #[cfg(feature = "use_os")]
234      {
235         #[cfg(windows)]
236         {
237            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
238            if !mprotect_ok {
239               return (false, false);
240            }
241            let decrypt_ok = self.decrypt_memory();
242            (decrypt_ok, mprotect_ok)
243         }
244         #[cfg(unix)]
245         {
246            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
247            (true, mprotect_ok)
248         }
249      }
250
251      #[cfg(not(feature = "use_os"))]
252      {
253         (true, true) // No-op: always "succeeds"
254      }
255   }
256
257   /// Immutable access to the array's data as a `&[T]`
258   pub fn unlock<F, R>(&self, f: F) -> R
259   where
260      F: FnOnce(&[T]) -> R,
261   {
262      self.unlock_memory();
263      let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
264      let result = f(slice);
265      self.lock_memory();
266      result
267   }
268
269   /// Mutable access to the array's data as a `&mut [T]`
270   pub fn unlock_mut<F, R>(&mut self, f: F) -> R
271   where
272      F: FnOnce(&mut [T]) -> R,
273   {
274      self.unlock_memory();
275      let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
276      let result = f(slice);
277      self.lock_memory();
278      result
279   }
280
281   /// Securely erases the contents of the array by zeroizing the memory.
282   pub fn erase(&mut self) {
283      self.unlock_mut(|slice| {
284         for element in slice.iter_mut() {
285            element.zeroize();
286         }
287      });
288   }
289}
290
291impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
292   type Output = T;
293   fn index(&self, index: usize) -> &Self::Output {
294      assert!(index < self.len(), "Index out of bounds");
295      unsafe {
296         let ptr = self.ptr.as_ptr().add(index);
297         &*ptr
298      }
299   }
300}
301
302impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
303   fn index_mut(&mut self, index: usize) -> &mut Self::Output {
304      assert!(index < self.len(), "Index out of bounds");
305      unsafe {
306         let ptr = self.ptr.as_ptr().add(index);
307         &mut *ptr
308      }
309   }
310}
311
312impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
313   fn drop(&mut self) {
314      self.erase();
315      self.unlock_memory();
316
317      let size = LENGTH * mem::size_of::<T>();
318      if size == 0 {
319         return;
320      }
321
322      unsafe {
323         #[cfg(feature = "use_os")]
324         {
325            memsec::free(self.ptr);
326         }
327         #[cfg(not(feature = "use_os"))]
328         {
329            let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
330            dealloc(self.ptr.as_ptr() as *mut u8, layout);
331         }
332      }
333   }
334}
335
336impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
337   fn clone(&self) -> Self {
338      let mut new_array = Self::empty().unwrap();
339      self.unlock(|src_slice| {
340         new_array.unlock_mut(|dest_slice| {
341            dest_slice.clone_from_slice(src_slice);
342         });
343      });
344      new_array
345   }
346}
347
348impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
349   type Error = Error;
350
351   /// Tries to convert a `SecureVec<u8>` into a `SecureArray<u8, LENGTH>`.
352   ///
353   /// This operation will only succeed if `vec.len() == LENGTH`.
354   ///
355   /// The `SecureVec` is consumed.
356   fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
357      if vec.len() != LENGTH {
358         return Err(Error::LengthMismatch);
359      }
360
361      let mut new_array = Self::empty()?;
362
363      vec.unlock_slice(|vec_slice| {
364         new_array.unlock_mut(|array_slice| {
365            array_slice.copy_from_slice(vec_slice);
366         });
367      });
368
369      Ok(new_array)
370   }
371}
372
/// Serializes the bytes as a sequence of `u8` elements.
///
/// The array is unlocked only while the serializer consumes it.
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
   where
      S: serde::Serializer,
   {
      self.unlock(|bytes| serializer.collect_seq(bytes))
   }
}
382
/// Deserializes exactly `LENGTH` bytes into a new secure allocation.
///
/// The matching `Serialize` impl writes a sequence, so a sequence is requested here as well;
/// byte-string input is also accepted for formats that hand the visitor raw bytes.
#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
   fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
   where
      D: serde::Deserializer<'de>,
   {
      struct SecureArrayVisitor<const L: usize>;

      impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
         type Value = SecureArray<u8, L>;

         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            write!(formatter, "a byte array of length {}", L)
         }

         fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
         where
            A: serde::de::SeqAccess<'de>,
         {
            // Collect into secure memory so the bytes never sit in an ordinary buffer.
            let mut data: SecureVec<u8> =
               SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;
            while let Some(byte) = seq.next_element()? {
               // Stop at the first surplus element instead of buffering unbounded input.
               if data.len() == L {
                  return Err(serde::de::Error::invalid_length(
                     L.saturating_add(1),
                     &self,
                  ));
               }
               data.push(byte);
            }

            // Check that the deserialized data has the exact length required.
            if data.len() != L {
               return Err(serde::de::Error::invalid_length(
                  data.len(),
                  &self,
               ));
            }

            SecureArray::try_from(data).map_err(serde::de::Error::custom)
         }

         fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
         where
            E: serde::de::Error,
         {
            // The length check and the conversion to a fixed-size reference are one step.
            let exact = <&[u8; L]>::try_from(v).map_err(|_| E::invalid_length(v.len(), &self))?;
            SecureArray::from_slice(exact).map_err(E::custom)
         }
      }

      deserializer.deserialize_seq(SecureArrayVisitor::<LENGTH>)
   }
}
423
#[cfg(all(test, feature = "use_os"))]
mod tests {
   use super::*;
   use std::process::{Command, Stdio};
   use std::sync::{Arc, Mutex};

   /// Both constructors copy the data; only `from_slice_mut` wipes its source.
   #[test]
   fn test_creation() {
      let wiped_source = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(wiped_source).unwrap();
      assert_eq!(array.len(), 3);
      array.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
      assert_eq!(wiped_source, &[0u8; 3]);

      let kept_source = &[1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice(kept_source).unwrap();
      assert_eq!(array.len(), 3);
      array.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
      assert_eq!(kept_source, &[1, 2, 3]);
   }

   /// A `SecureVec` of matching length converts into an array with the same bytes.
   #[test]
   fn test_from_secure_vec() {
      let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
      let array: SecureArray<u8, 3> = vec.try_into().unwrap();
      assert_eq!(array.len(), 3);
      array.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
   }

   /// `erase` leaves every element zeroed.
   #[test]
   fn test_erase() {
      let source: &mut [u8; 3] = &mut [1, 2, 3];
      let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(source).unwrap();
      array.erase();
      array.unlock(|slice| assert_eq!(slice, &[0u8; 3]));
   }

   /// The aligned size is non-zero for filled and for empty arrays alike.
   #[test]
   fn test_size_cannot_be_zero() {
      let filled: SecureArray<u8, 3> = SecureArray::from_slice(&[1, 2, 3]).unwrap();
      assert!(filled.aligned_size() > 0);

      let blank: SecureArray<u8, 3> = SecureArray::empty().unwrap();
      assert!(blank.aligned_size() > 0);
   }

   /// Building a zero-length array is expected to panic somewhere along this path.
   #[test]
   #[should_panic]
   fn test_length_cannot_be_zero() {
      let secure_vec = SecureVec::new().unwrap();
      let _secure_array: SecureArray<u8, 0> = SecureArray::try_from(secure_vec).unwrap();
   }

   /// A manual unlock followed by a lock reports success for both flags each time.
   #[test]
   fn lock_unlock() {
      let source: &mut [u8; 3] = &mut [1, 2, 3];
      let secure: SecureArray<u8, 3> = SecureArray::from_slice_mut(source).unwrap();
      assert!(secure.aligned_size() > 0);

      let (decrypted, unlocked) = secure.unlock_memory();
      assert!(decrypted);
      assert!(unlocked);

      let (encrypted, locked) = secure.lock_memory();
      assert!(encrypted);
      assert!(locked);
   }

   /// A clone carries the same bytes and leaves the original intact.
   #[test]
   fn test_clone() {
      let mut original: SecureArray<u8, 3> = SecureArray::empty().unwrap();
      original.unlock_mut(|slice| {
         slice[0] = 1;
         slice[1] = 2;
         slice[2] = 3;
      });

      let copy = original.clone();

      copy.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
      original.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
   }

   /// Five threads each bump the first byte once while holding the mutex.
   #[test]
   fn test_thread_safety() {
      let source: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(source).unwrap();
      let shared = Arc::new(Mutex::new(array));
      let mut workers = Vec::new();

      for _ in 0..5u8 {
         let handle_to_shared = Arc::clone(&shared);
         workers.push(std::thread::spawn(move || {
            let mut guard = handle_to_shared.lock().unwrap();
            guard.unlock_mut(|slice| {
               slice[0] += 1;
            });
         }));
      }

      for worker in workers {
         worker.join().unwrap();
      }

      let final_array = shared.lock().unwrap();
      final_array.unlock(|slice| {
         assert_eq!(slice[0], 6);
         assert_eq!(slice[1], 2);
         assert_eq!(slice[2], 3);
      });
   }

   /// Indexing a locked array must crash the process.
   ///
   /// The test re-runs its own binary with a marker argument; the child performs the faulting
   /// access and the parent inspects how the child terminated.
   #[test]
   fn test_index_should_fail_when_locked() {
      let arg = "CRASH_TEST_ARRAY_LOCKED";

      // Child mode: do the forbidden access. Reaching `exit` means no fault occurred.
      if std::env::args().any(|a| a == arg) {
         let source: &mut [u8; 3] = &mut [1, 2, 3];
         let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(source).unwrap();
         let _value = core::hint::black_box(array[0]);

         std::process::exit(1);
      }

      let this_binary = std::env::current_exe().unwrap();
      let child = Command::new(this_binary)
         .arg("array::tests::test_index_should_fail_when_locked")
         .arg(arg)
         .arg("--nocapture")
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .spawn()
         .expect("Failed to spawn child process");

      let output = child.wait_with_output().expect("Failed to wait on child");
      let status = output.status;

      assert!(
         !status.success(),
         "Process exited successfully with code {:?}, but it should have crashed.",
         status.code()
      );

      #[cfg(unix)]
      {
         use std::os::unix::process::ExitStatusExt;
         let signal = status
            .signal()
            .expect("Process was not terminated by a signal on Unix.");
         let is_memory_fault = signal == libc::SIGSEGV || signal == libc::SIGBUS;
         assert!(
            is_memory_fault,
            "Process terminated with unexpected signal: {}",
            signal
         );
         println!(
            "Test passed: Process correctly terminated with signal {}.",
            signal
         );
      }

      #[cfg(windows)]
      {
         const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
         assert_eq!(
            status.code(),
            Some(STATUS_ACCESS_VIOLATION),
            "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
            status.code()
         );
         eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
      }
   }

   /// Writes made through `unlock_mut` are visible to a later `unlock`.
   #[test]
   fn test_unlock_mut() {
      let source: &mut [u8; 3] = &mut [1, 2, 3];
      let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(source).unwrap();

      array.unlock_mut(|slice| {
         slice[1] = 100;
      });

      array.unlock(|slice| assert_eq!(slice, &[1, 100, 3]));
   }

   /// JSON round-trip through both the string and the byte-vector entry points.
   #[cfg(feature = "serde")]
   #[test]
   fn test_serde() {
      let source: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(source).unwrap();
      let json_string = serde_json::to_string(&array).expect("Serialization failed");
      let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");

      let from_string: SecureArray<u8, 3> =
         serde_json::from_str(&json_string).expect("Deserialization failed");
      let from_bytes: SecureArray<u8, 3> =
         serde_json::from_slice(&json_bytes).expect("Deserialization failed");

      from_string.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
      from_bytes.unlock(|slice| assert_eq!(slice, &[1, 2, 3]));
   }
}
651      });
652   }
653}