secure_types/
array.rs

1#[cfg(not(feature = "use_os"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::{Error, SecureVec};
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "use_os")]
9use super::{alloc, free};
10#[cfg(feature = "use_os")]
11use memsec::Prot;
12
13/// A fixed-size array allocated in a secure memory region.
14///
15/// ## Security Model
16///
17/// When compiled with the `use_os` feature (the default), it provides several layers of protection:
18/// - **Zeroization on Drop**: The memory is zeroized when the array is dropped.
19/// - **Memory Locking**: The underlying memory pages are locked using `mlock` & `madvise` for (Unix) or
20///   `VirtualLock` & `VirtualProtect` for (Windows) to prevent the OS from memory-dump/swap to disk or other processes accessing the memory.
21///
22/// In a `no_std` environment, it falls back to providing only the **zeroization-on-drop** guarantee.
23///
24/// # Program Termination
25///
26/// Direct indexing (e.g., `array[0]`) on a locked array will cause the operating system
27/// to terminate the process with an access violation error. Always use the provided
28/// scope methods (`unlock`, `unlock_mut`) for safe access.
29///
30/// # Notes
31///
32/// If you return a new allocated `[T; LENGTH]` from one of the unlock methods you are responsible for zeroizing the memory.
33///
34/// # Example
35///
36/// ```
37/// use secure_types::{SecureArray, Zeroize};
38///
39/// let exposed_key: &mut [u8; 32] = &mut [1u8; 32];
40/// let secure_key: SecureArray<u8, 32> = SecureArray::from_slice_mut(exposed_key).unwrap();
41///
42/// secure_key.unlock(|unlocked_slice| {
43///     assert_eq!(unlocked_slice.len(), 32);
44///     assert_eq!(unlocked_slice[0], 1);
45/// });
46///
47/// // Not recommended but if you allocate a new [u8; LENGTH] make sure to zeroize it
48/// let mut exposed = secure_key.unlock(|unlocked_slice| {
49///     [unlocked_slice[0], unlocked_slice[1], unlocked_slice[2]]
50/// });
51///
52/// // Do what you need to to do with the new array
53/// // When you are done with it, zeroize it
54/// exposed.zeroize();
55/// ```
56pub struct SecureArray<T, const LENGTH: usize>
57where
58   T: Zeroize,
59{
60   ptr: NonNull<T>,
61   _marker: PhantomData<T>,
62}
63
// SAFETY: every `SecureArray` owns its own allocation (all constructors, including
// `Clone`, go through `empty()` and allocate a fresh region), so moving the array to
// another thread is sound whenever `T` itself may be sent.
unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
// SAFETY: through `&SecureArray` only shared access to the elements is handed out
// (`&[T]` via `unlock`, `&T` via `Index`), which requires `T: Sync`.
// NOTE(review): `unlock(&self)` flips the page protection of the shared region via
// `unlock_memory` / `lock_memory`. Two threads calling `unlock` on the same array
// through shared references can interleave so that one re-locks the memory while the
// other is still reading it, which faults. Until this impl is revisited, callers should
// serialize access (e.g. wrap the array in a `Mutex`).
unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
66
67impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
68where
69   T: Zeroize,
70{
71   /// Creates an empty (but allocated) SecureArray.
72   ///
73   /// The memory is allocated but not initialized, and it's the caller's responsibility to fill it.
74   pub fn empty() -> Result<Self, Error> {
75      let size = LENGTH * mem::size_of::<T>();
76      if size == 0 {
77         // Cannot create a zero-sized secure array
78         return Err(Error::LengthCannotBeZero);
79      }
80
81      let ptr = unsafe { alloc::<T>(size)? };
82
83      let secure_array = SecureArray {
84         ptr,
85         _marker: PhantomData,
86      };
87
88      let locked = secure_array.lock_memory();
89
90      #[cfg(feature = "use_os")]
91      if !locked {
92         return Err(Error::LockFailed);
93      }
94
95      Ok(secure_array)
96   }
97
98   /// Creates a new SecureArray from a `&mut [T; LENGTH]`.
99   ///
100   /// The passed slice is zeroized afterwards
101   pub fn from_slice_mut(content: &mut [T; LENGTH]) -> Result<Self, Error> {
102      let secure_array = Self::empty()?;
103
104      let unlocked = secure_array.unlock_memory();
105
106      if !unlocked {
107         return Err(Error::UnlockFailed);
108      }
109
110      unsafe {
111         // Copy the data from the source array into the secure memory region
112         core::ptr::copy_nonoverlapping(
113            content.as_ptr(),
114            secure_array.ptr.as_ptr(),
115            LENGTH,
116         );
117      }
118
119      content.zeroize();
120
121      let locked = secure_array.lock_memory();
122
123      #[cfg(feature = "use_os")]
124      if !locked {
125         return Err(Error::LockFailed);
126      }
127
128      Ok(secure_array)
129   }
130
131   /// Creates a new SecureArray from a `&[T; LENGTH]`.
132   ///
133   /// The array is not zeroized, you are responsible for zeroizing it
134   pub fn from_slice(content: &[T; LENGTH]) -> Result<Self, Error> {
135      let secure_array = Self::empty()?;
136
137      let unlocked = secure_array.unlock_memory();
138
139      if !unlocked {
140         return Err(Error::UnlockFailed);
141      }
142
143      unsafe {
144         // Copy the data from the source array into the secure memory region
145         core::ptr::copy_nonoverlapping(
146            content.as_ptr(),
147            secure_array.ptr.as_ptr(),
148            LENGTH,
149         );
150      }
151
152      let locked = secure_array.lock_memory();
153
154      #[cfg(feature = "use_os")]
155      if !locked {
156         return Err(Error::LockFailed);
157      }
158
159      Ok(secure_array)
160   }
161
162   pub fn len(&self) -> usize {
163      LENGTH
164   }
165
166   pub fn is_empty(&self) -> bool {
167      self.len() == 0
168   }
169
170   pub(crate) fn lock_memory(&self) -> bool {
171      #[cfg(feature = "use_os")]
172      {
173         #[cfg(windows)]
174         {
175            super::mprotect(self.ptr, Prot::NoAccess)
176         }
177         #[cfg(unix)]
178         {
179            super::mprotect(self.ptr, Prot::NoAccess)
180         }
181      }
182      #[cfg(not(feature = "use_os"))]
183      {
184         true // No-op: always "succeeds"
185      }
186   }
187
188   pub(crate) fn unlock_memory(&self) -> bool {
189      #[cfg(feature = "use_os")]
190      {
191         #[cfg(windows)]
192         {
193            super::mprotect(self.ptr, Prot::ReadWrite)
194         }
195         #[cfg(unix)]
196         {
197            super::mprotect(self.ptr, Prot::ReadWrite)
198         }
199      }
200
201      #[cfg(not(feature = "use_os"))]
202      {
203         true // No-op: always "succeeds"
204      }
205   }
206
207   /// Immutable access to the array's data as a `&[T]`
208   pub fn unlock<F, R>(&self, f: F) -> R
209   where
210      F: FnOnce(&[T]) -> R,
211   {
212      self.unlock_memory();
213      let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
214      let result = f(slice);
215      self.lock_memory();
216      result
217   }
218
219   /// Mutable access to the array's data as a `&mut [T]`
220   pub fn unlock_mut<F, R>(&mut self, f: F) -> R
221   where
222      F: FnOnce(&mut [T]) -> R,
223   {
224      self.unlock_memory();
225      let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
226      let result = f(slice);
227      self.lock_memory();
228      result
229   }
230
231   /// Securely erases the contents of the array by zeroizing the memory.
232   pub fn erase(&mut self) {
233      self.unlock_mut(|slice| {
234         for element in slice.iter_mut() {
235            element.zeroize();
236         }
237      });
238   }
239}
240
241impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
242   type Output = T;
243   fn index(&self, index: usize) -> &Self::Output {
244      assert!(index < self.len(), "Index out of bounds");
245      unsafe {
246         let ptr = self.ptr.as_ptr().add(index);
247         &*ptr
248      }
249   }
250}
251
252impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
253   fn index_mut(&mut self, index: usize) -> &mut Self::Output {
254      assert!(index < self.len(), "Index out of bounds");
255      unsafe {
256         let ptr = self.ptr.as_ptr().add(index);
257         &mut *ptr
258      }
259   }
260}
261
262impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
263   fn drop(&mut self) {
264      self.erase();
265      self.unlock_memory();
266
267      let size = LENGTH * mem::size_of::<T>();
268      if size == 0 {
269         return;
270      }
271
272      #[cfg(feature = "use_os")]
273      free(self.ptr);
274
275      #[cfg(not(feature = "use_os"))]
276      {
277         let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
278         dealloc(self.ptr.as_ptr() as *mut u8, layout);
279      }
280   }
281}
282
283impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
284   fn clone(&self) -> Self {
285      let mut new_array = Self::empty().unwrap();
286      self.unlock(|src_slice| {
287         new_array.unlock_mut(|dest_slice| {
288            dest_slice.clone_from_slice(src_slice);
289         });
290      });
291      new_array
292   }
293}
294
295impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
296   type Error = Error;
297
298   /// Tries to convert a `SecureVec<u8>` into a `SecureArray<u8, LENGTH>`.
299   ///
300   /// This operation will only succeed if `vec.len() == LENGTH`.
301   ///
302   /// The `SecureVec` is consumed.
303   fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
304      if vec.len() != LENGTH {
305         return Err(Error::LengthMismatch);
306      }
307
308      let mut new_array = Self::empty()?;
309
310      vec.unlock_slice(|vec_slice| {
311         new_array.unlock_mut(|array_slice| {
312            array_slice.copy_from_slice(vec_slice);
313         });
314      });
315
316      Ok(new_array)
317   }
318}
319
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
   /// Serializes the bytes as a sequence, feeding them to the serializer straight
   /// from the temporarily unlocked region.
   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
   where
      S: serde::Serializer,
   {
      self.unlock(move |bytes| serializer.collect_seq(bytes))
   }
}
329
#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
   /// Deserializes exactly `LENGTH` bytes into a new [`SecureArray`].
   ///
   /// Both sequence-shaped input (e.g. a JSON array of numbers) and native byte strings
   /// (what binary formats typically hand back for `deserialize_bytes`) are accepted.
   fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
   where
      D: serde::Deserializer<'de>,
   {
      struct SecureArrayVisitor<const L: usize>;

      impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
         type Value = SecureArray<u8, L>;

         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            write!(formatter, "a byte array of length {}", L)
         }

         /// Handles formats that deliver the payload as one contiguous byte string.
         ///
         /// NOTE(review): `v` is owned by the deserializer and cannot be wiped here.
         fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
         where
            E: serde::de::Error,
         {
            let bytes = <&[u8; L]>::try_from(v).map_err(|_| E::invalid_length(v.len(), &self))?;
            SecureArray::from_slice(bytes).map_err(E::custom)
         }

         /// Handles formats that deliver the payload element by element.
         fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
         where
            A: serde::de::SeqAccess<'de>,
         {
            let mut data: SecureVec<u8> =
               SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;
            while let Some(byte) = seq.next_element()? {
               // Stop as soon as the input is too long instead of buffering an
               // arbitrarily large (possibly hostile) sequence.
               if data.len() == L {
                  return Err(serde::de::Error::invalid_length(L.saturating_add(1), &self));
               }
               data.push(byte);
            }

            // Check that the deserialized data has the exact length required.
            if data.len() != L {
               return Err(serde::de::Error::invalid_length(
                  data.len(),
                  &self,
               ));
            }

            SecureArray::try_from(data).map_err(serde::de::Error::custom)
         }
      }

      deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
   }
}
370
// These tests exercise the real OS page protection, so they are only built
// together with the `use_os` feature.
#[cfg(all(test, feature = "use_os"))]
mod tests {
   use super::*;
   use std::process::{Command, Stdio};
   use std::sync::{Arc, Mutex};

   /// Both constructors copy the data in; only `from_slice_mut` zeroizes the source.
   #[test]
   fn test_creation() {
      let exposed_mut = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_mut).unwrap();
      assert_eq!(array.len(), 3);

      array.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });

      assert_eq!(exposed_mut, &[0u8; 3]);

      let exposed = &[1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice(exposed).unwrap();
      assert_eq!(array.len(), 3);

      array.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });

      assert_eq!(exposed, &[1, 2, 3]);
   }

   /// Converting a `SecureVec` of matching length keeps its contents.
   #[test]
   fn test_from_secure_vec() {
      let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
      let array: SecureArray<u8, 3> = vec.try_into().unwrap();
      assert_eq!(array.len(), 3);
      array.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   /// `erase` zeroizes every element in place.
   #[test]
   fn test_erase() {
      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
      let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
      array.erase();
      array.unlock(|slice| {
         assert_eq!(slice, &[0u8; 3]);
      });
   }

   /// With `LENGTH == 0` the length check passes (the vec is empty too), but `empty()`
   /// rejects the zero-sized allocation, so the `unwrap` panics.
   #[test]
   #[should_panic]
   fn test_length_cannot_be_zero() {
      let secure_vec = SecureVec::new().unwrap();
      let _secure_array: SecureArray<u8, 0> = SecureArray::try_from(secure_vec).unwrap();
   }

   /// The raw protection toggles report success on a live array.
   #[test]
   fn lock_unlock() {
      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
      let secure: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();

      let unlocked = secure.unlock_memory();
      assert!(unlocked);

      let locked = secure.lock_memory();
      assert!(locked);
   }

   /// Cloning copies the contents and leaves the source readable.
   #[test]
   fn test_clone() {
      let mut array1: SecureArray<u8, 3> = SecureArray::empty().unwrap();
      array1.unlock_mut(|slice| {
         slice[0] = 1;
         slice[1] = 2;
         slice[2] = 3;
      });

      let array2 = array1.clone();

      array2.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });

      array1.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   /// Five threads each increment the first element once through a `Mutex`, so it
   /// ends up at 1 + 5 = 6 while the other elements are unchanged.
   #[test]
   fn test_thread_safety() {
      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
      let arc_array = Arc::new(Mutex::new(array));
      let mut handles = Vec::new();

      for _ in 0..5u8 {
         let array_clone = Arc::clone(&arc_array);
         let handle = std::thread::spawn(move || {
            let mut guard = array_clone.lock().unwrap();
            guard.unlock_mut(|slice| {
               slice[0] += 1;
            });
         });
         handles.push(handle);
      }

      for handle in handles {
         handle.join().unwrap();
      }

      let final_array = arc_array.lock().unwrap();
      final_array.unlock(|slice| {
         assert_eq!(slice[0], 6);
         assert_eq!(slice[1], 2);
         assert_eq!(slice[2], 3);
      });
   }

   /// Indexing a locked array must crash the process.
   ///
   /// The test re-runs its own binary, filtered to this test and with a marker argument.
   /// In that child run the marker is present, so it reads `array[0]` directly; if the
   /// read does not fault, the child exits with code 1. The parent asserts that the child
   /// did not succeed and, per platform, that it died from SIGSEGV/SIGBUS (Unix) or
   /// STATUS_ACCESS_VIOLATION (Windows).
   #[test]
   fn test_index_should_fail_when_locked() {
      let arg = "CRASH_TEST_ARRAY_LOCKED";

      if std::env::args().any(|a| a == arg) {
         // Child process: read through `Index` while the memory is locked.
         let exposed: &mut [u8; 3] = &mut [1, 2, 3];
         let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
         let _value = core::hint::black_box(array[0]);

         // Only reached if the read above did not fault.
         std::process::exit(1);
      }

      let child = Command::new(std::env::current_exe().unwrap())
         .arg("array::tests::test_index_should_fail_when_locked")
         .arg(arg)
         .arg("--nocapture")
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .spawn()
         .expect("Failed to spawn child process");

      let output = child.wait_with_output().expect("Failed to wait on child");
      let status = output.status;

      assert!(
         !status.success(),
         "Process exited successfully with code {:?}, but it should have crashed.",
         status.code()
      );

      #[cfg(unix)]
      {
         use std::os::unix::process::ExitStatusExt;
         let signal = status
            .signal()
            .expect("Process was not terminated by a signal on Unix.");
         assert!(
            signal == libc::SIGSEGV || signal == libc::SIGBUS,
            "Process terminated with unexpected signal: {}",
            signal
         );
         println!(
            "Test passed: Process correctly terminated with signal {}.",
            signal
         );
      }

      #[cfg(windows)]
      {
         const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
         assert_eq!(
            status.code(),
            Some(STATUS_ACCESS_VIOLATION),
            "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
            status.code()
         );
         eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
      }
   }

   /// Writes made in `unlock_mut` are visible through a later `unlock`.
   #[test]
   fn test_unlock_mut() {
      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
      let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();

      array.unlock_mut(|slice| {
         slice[1] = 100;
      });

      array.unlock(|slice| {
         assert_eq!(slice, &[1, 100, 3]);
      });
   }

   /// Round-trips an array through JSON text and through JSON bytes.
   #[cfg(feature = "serde")]
   #[test]
   fn test_serde() {
      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
      let json_string = serde_json::to_string(&array).expect("Serialization failed");
      let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");

      let deserialized_string: SecureArray<u8, 3> =
         serde_json::from_str(&json_string).expect("Deserialization failed");

      let deserialized_bytes: SecureArray<u8, 3> =
         serde_json::from_slice(&json_bytes).expect("Deserialization failed");

      deserialized_string.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });

      deserialized_bytes.unlock(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }
}