// secure_types/array.rs

#[cfg(not(feature = "std"))]
use alloc::alloc::{Layout, alloc, dealloc};

use super::{Error, SecureVec};
use core::{marker::PhantomData, mem, ptr::NonNull};
use zeroize::Zeroize;

#[cfg(feature = "std")]
use super::page_size;
#[cfg(feature = "std")]
use memsec::Prot;

13/// A fixed-size array allocated in a secure memory region.
14///
15/// `SecureArray` provides the same core security guarantees as the other types in this
16/// crate, including zeroization on drop and optional memory locking/encryption when
17// compiled with the `std` feature.
18///
19/// It is ideal for secrets of a known, fixed length.
20///
21/// # Program Termination
22///
23/// Direct indexing (e.g., `array[0]`) on a locked array will cause the operating system
24/// to terminate the process with an access violation error. Always use the provided
25/// scope methods (`unlocked_scope`, `unlocked_mut_scope`) for safe access.
26///
27/// # Examples
28///
29/// ```
30/// use secure_types::SecureArray;
31///
32/// let key_data = [1u8; 32];
33/// let secure_key: SecureArray<u8, 32> = SecureArray::new(key_data).unwrap();
34///
35/// secure_key.unlocked_scope(|unlocked_slice| {
36///     assert_eq!(unlocked_slice.len(), 32);
37///     assert_eq!(unlocked_slice[0], 1);
38/// });
39/// ```
40pub struct SecureArray<T, const LENGTH: usize>
41where
42   T: Zeroize,
43{
44   ptr: NonNull<T>,
45   _marker: PhantomData<T>,
46}
47
48unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
49unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
50
51impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
52where
53   T: Zeroize,
54{
55   /// Creates an empty (but allocated) SecureArray.
56   /// The memory is allocated but not initialized, and it's the caller's responsibility to fill it.
57   pub fn empty() -> Result<Self, Error> {
58      let size = LENGTH * mem::size_of::<T>();
59      if size == 0 {
60         // Cannot create a zero-sized secure array
61         return Err(Error::AllocationFailed);
62      }
63
64      #[cfg(feature = "std")]
65      let new_ptr = {
66         let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
67         let allocated_ptr = unsafe { memsec::malloc_sized(aligned_size) };
68         allocated_ptr.ok_or(Error::AllocationFailed)?.as_ptr() as *mut T
69      };
70
71      #[cfg(not(feature = "std"))]
72      let new_ptr = {
73         let layout = Layout::from_size_align(size, mem::align_of::<T>())
74            .map_err(|_| Error::AllocationFailed)?;
75         let ptr = unsafe { alloc::alloc(layout) as *mut T };
76         if ptr.is_null() {
77            return Err(Error::AllocationFailed);
78         }
79         ptr
80      };
81
82      let non_null = NonNull::new(new_ptr).ok_or(Error::NullAllocation)?;
83
84      let secure_array = SecureArray {
85         ptr: non_null,
86         _marker: PhantomData,
87      };
88
89      let (encrypted, locked) = secure_array.lock_memory();
90
91      #[cfg(feature = "std")]
92      if !locked {
93         return Err(Error::LockFailed);
94      }
95
96      #[cfg(feature = "std")]
97      if !encrypted {
98         return Err(Error::CryptProtectMemoryFailed);
99      }
100
101      Ok(secure_array)
102   }
103
104   /// Creates a new SecureArray from a given array.
105   pub fn new(mut content: [T; LENGTH]) -> Result<Self, Error> {
106      let secure_array = Self::empty()?;
107
108      secure_array.unlock_memory();
109
110      unsafe {
111         // Copy the data from the source array into the secure memory region
112         core::ptr::copy_nonoverlapping(
113            content.as_ptr(),
114            secure_array.ptr.as_ptr(),
115            LENGTH,
116         );
117      }
118
119      content.zeroize();
120
121      let (encrypted, locked) = secure_array.lock_memory();
122
123      #[cfg(feature = "std")]
124      if !locked {
125         return Err(Error::LockFailed);
126      }
127
128      #[cfg(feature = "std")]
129      if !encrypted {
130         return Err(Error::CryptProtectMemoryFailed);
131      }
132
133      Ok(secure_array)
134   }
135
136   pub fn len(&self) -> usize {
137      LENGTH
138   }
139
140   pub fn as_ptr(&self) -> *const T {
141      self.ptr.as_ptr()
142   }
143
144   pub fn as_mut_ptr(&mut self) -> *mut u8 {
145      self.ptr.as_ptr() as *mut u8
146   }
147
148   #[allow(dead_code)]
149   fn allocated_byte_size(&self) -> usize {
150      let size = self.len() * mem::size_of::<T>();
151      #[cfg(feature = "std")]
152      {
153         (size + page_size() - 1) & !(page_size() - 1)
154      }
155      #[cfg(not(feature = "std"))]
156      {
157         size // No page alignment in no_std
158      }
159   }
160
161   #[cfg(all(feature = "std", windows))]
162   fn encypt_memory(&self) -> bool {
163      let ptr = self.as_ptr() as *mut u8;
164      super::crypt_protect_memory(ptr, self.allocated_byte_size())
165   }
166
167   #[cfg(all(feature = "std", windows))]
168   fn decrypt_memory(&self) -> bool {
169      let ptr = self.as_ptr() as *mut u8;
170      super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
171   }
172
173   pub(crate) fn lock_memory(&self) -> (bool, bool) {
174      #[cfg(feature = "std")]
175      {
176         #[cfg(windows)]
177         {
178            let encrypt_ok = self.encypt_memory();
179            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
180            (encrypt_ok, mprotect_ok)
181         }
182         #[cfg(unix)]
183         {
184            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
185            (true, mprotect_ok)
186         }
187      }
188      #[cfg(not(feature = "std"))]
189      {
190         (true, true) // No-op: always "succeeds"
191      }
192   }
193
194   pub(crate) fn unlock_memory(&self) -> (bool, bool) {
195      #[cfg(feature = "std")]
196      {
197         #[cfg(windows)]
198         {
199            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
200            if !mprotect_ok {
201               return (false, false);
202            }
203            let decrypt_ok = self.decrypt_memory();
204            (decrypt_ok, mprotect_ok)
205         }
206         #[cfg(unix)]
207         {
208            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
209            (true, mprotect_ok)
210         }
211      }
212
213      #[cfg(not(feature = "std"))]
214      {
215         (true, true) // No-op: always "succeeds"
216      }
217   }
218
219   /// Provides scoped, immutable access to the array's data as a slice.
220   pub fn unlocked_scope<F, R>(&self, f: F) -> R
221   where
222      F: FnOnce(&[T]) -> R,
223   {
224      self.unlock_memory();
225      let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
226      let result = f(slice);
227      self.lock_memory();
228      result
229   }
230
231   /// Provides scoped, mutable access to the array's data as a mutable slice.
232   pub fn unlocked_mut_scope<F, R>(&mut self, f: F) -> R
233   where
234      F: FnOnce(&mut [T]) -> R,
235   {
236      self.unlock_memory();
237      let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
238      let result = f(slice);
239      self.lock_memory();
240      result
241   }
242
243   /// Securely erases the contents of the array by zeroizing the memory.
244   pub fn erase(&mut self) {
245      self.unlocked_mut_scope(|slice| {
246         for element in slice.iter_mut() {
247            element.zeroize();
248         }
249      });
250   }
251}
252
253impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
254   type Output = T;
255   fn index(&self, index: usize) -> &Self::Output {
256      assert!(index < self.len(), "Index out of bounds");
257      unsafe {
258         let ptr = self.ptr.as_ptr().add(index);
259         let reference = &*ptr;
260         reference
261      }
262   }
263}
264
265impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
266   fn index_mut(&mut self, index: usize) -> &mut Self::Output {
267      assert!(index < self.len(), "Index out of bounds");
268      unsafe {
269         let ptr = self.ptr.as_ptr().add(index);
270         let reference = &mut *ptr;
271         reference
272      }
273   }
274}
275
276impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
277   fn drop(&mut self) {
278      self.erase();
279      self.unlock_memory();
280
281      let size = LENGTH * mem::size_of::<T>();
282      if size == 0 {
283         return;
284      }
285
286      unsafe {
287         #[cfg(feature = "std")]
288         {
289            memsec::free(self.ptr);
290         }
291         #[cfg(not(feature = "std"))]
292         {
293            // Recreate the layout to deallocate correctly
294            let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
295            dealloc(self.ptr.as_ptr() as *mut u8, layout);
296         }
297      }
298   }
299}
300
301impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
302   fn clone(&self) -> Self {
303      let mut new_array = Self::empty().unwrap();
304      self.unlocked_scope(|src_slice| {
305         new_array.unlocked_mut_scope(|dest_slice| {
306            dest_slice.clone_from_slice(src_slice);
307         });
308      });
309      new_array
310   }
311}
312
313impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
314   type Error = Error;
315
316   /// Tries to convert a `SecureVec<u8>` into a `SecureArray<u8, LENGTH>`.
317   ///
318   /// This operation will only succeed if `vec.len() == LENGTH`.
319   /// The original `SecureVec` is consumed.
320   fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
321      if vec.len() != LENGTH {
322         return Err(Error::LengthMismatch);
323      }
324
325      let mut new_array = Self::empty()?;
326
327      vec.slice_scope(|vec_slice| {
328         new_array.unlocked_mut_scope(|array_slice| {
329            array_slice.copy_from_slice(vec_slice);
330         });
331      });
332
333      Ok(new_array)
334   }
335}
336
337impl<T, const LENGTH: usize> TryFrom<[T; LENGTH]> for SecureArray<T, LENGTH>
338where
339   T: Zeroize,
340{
341   type Error = Error;
342   fn try_from(s: [T; LENGTH]) -> Result<Self, Error> {
343      Self::new(s)
344   }
345}
346
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
   /// Emits the bytes as a plain sequence of `u8`, reading them inside an unlocked scope.
   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
   where
      S: serde::Serializer,
   {
      self.unlocked_scope(move |bytes: &[u8]| serializer.collect_seq(bytes))
   }
}

#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
   /// Deserializes exactly `LENGTH` bytes into secure memory.
   ///
   /// Two input shapes are accepted:
   /// - a sequence of `u8`, which is what this type's `Serialize` impl emits;
   /// - a native byte string, for formats that answer `deserialize_bytes` via `visit_bytes`.
   ///
   /// Any other length is rejected with `invalid_length`.
   fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
   where
      D: serde::Deserializer<'de>,
   {
      struct SecureArrayVisitor<const L: usize>;

      impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
         type Value = SecureArray<u8, L>;

         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            write!(formatter, "a byte array of length {}", L)
         }

         fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
         where
            A: serde::de::SeqAccess<'de>,
         {
            let mut data: SecureVec<u8> =
               SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;

            // Store at most `L` bytes so an oversized input cannot grow the buffer.
            while data.len() < L {
               match seq.next_element()? {
                  Some(byte) => {
                     data.push(byte);
                  }
                  None => return Err(serde::de::Error::invalid_length(data.len(), &self)),
               }
            }

            // Count any trailing elements without storing them, so the error reports the
            // real input length.
            let mut extra = 0usize;
            while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
               extra += 1;
            }
            if extra != 0 {
               return Err(serde::de::Error::invalid_length(
                  L.saturating_add(extra),
                  &self,
               ));
            }

            SecureArray::try_from(data).map_err(serde::de::Error::custom)
         }

         fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
         where
            E: serde::de::Error,
         {
            if bytes.len() != L {
               return Err(E::invalid_length(bytes.len(), &self));
            }
            let mut array = SecureArray::<u8, L>::empty().map_err(E::custom)?;
            array.unlocked_mut_scope(|slot| slot.copy_from_slice(bytes));
            Ok(array)
         }
      }

      deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
   }
}

#[cfg(all(test, feature = "std"))]
mod tests {
   use super::*;
   use std::process::{Command, Stdio};
   use std::sync::{Arc, Mutex};

   /// Asserts, inside an unlocked scope, that `array` holds exactly `expected`.
   fn assert_contents<T, const N: usize>(array: &SecureArray<T, N>, expected: &[T])
   where
      T: Zeroize + PartialEq + core::fmt::Debug,
   {
      array.unlocked_scope(|slice| assert_eq!(slice, expected));
   }

   #[test]
   fn test_creation() {
      let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
      assert_eq!(array.len(), 3);
      assert_contents(&array, &[1, 2, 3]);
   }

   #[test]
   fn test_from_secure_vec() {
      let source: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
      let array: SecureArray<u8, 3> = source.try_into().unwrap();
      assert_eq!(array.len(), 3);
      assert_contents(&array, &[1, 2, 3]);
   }

   #[test]
   fn test_erase() {
      let mut array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
      array.erase();
      assert_contents(&array, &[0u8; 3]);
   }

   #[test]
   fn test_clone() {
      let mut original: SecureArray<u8, 3> = SecureArray::empty().unwrap();
      original.unlocked_mut_scope(|slice| {
         for (slot, value) in slice.iter_mut().zip(1u8..) {
            *slot = value;
         }
      });

      let copy = original.clone();

      assert_contents(&copy, &[1, 2, 3]);
      assert_contents(&original, &[1, 2, 3]);
   }

   #[test]
   fn test_thread_safety() {
      let shared = Arc::new(Mutex::new(SecureArray::new([1, 2, 3]).unwrap()));

      let workers: Vec<_> = (0..5u8)
         .map(|_| {
            let array = Arc::clone(&shared);
            std::thread::spawn(move || {
               array
                  .lock()
                  .unwrap()
                  .unlocked_mut_scope(|slice| slice[0] += 1);
            })
         })
         .collect();

      for worker in workers {
         worker.join().unwrap();
      }

      let final_array = shared.lock().unwrap();
      final_array.unlocked_scope(|slice| assert_eq!(slice[0], 6));
   }

   #[test]
   fn test_index_should_fail_when_locked() {
      const CHILD_FLAG: &str = "CRASH_TEST_ARRAY_LOCKED";

      // Child mode: read through `Index` while locked; the OS should kill us before `exit`.
      let running_as_child = std::env::args().any(|a| a == CHILD_FLAG);
      if running_as_child {
         let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
         let _value = core::hint::black_box(array[0]);

         std::process::exit(1);
      }

      // Parent mode: re-run only this test in a child process and inspect how it ended.
      let output = Command::new(std::env::current_exe().unwrap())
         .arg("array::tests::test_index_should_fail_when_locked")
         .arg(CHILD_FLAG)
         .arg("--nocapture")
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .spawn()
         .expect("Failed to spawn child process")
         .wait_with_output()
         .expect("Failed to wait on child");
      let status = output.status;

      assert!(
         !status.success(),
         "Process exited successfully with code {:?}, but it should have crashed.",
         status.code()
      );

      #[cfg(unix)]
      {
         use std::os::unix::process::ExitStatusExt;

         let signal = status
            .signal()
            .expect("Process was not terminated by a signal on Unix.");
         let memory_fault = matches!(signal, libc::SIGSEGV | libc::SIGBUS);
         assert!(
            memory_fault,
            "Process terminated with unexpected signal: {}",
            signal
         );
         println!(
            "Test passed: Process correctly terminated with signal {}.",
            signal
         );
      }

      #[cfg(windows)]
      {
         const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
         assert_eq!(
            status.code(),
            Some(STATUS_ACCESS_VIOLATION),
            "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
            status.code()
         );
         eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
      }
   }

   #[test]
   fn test_mutable_access_in_scope() {
      let mut array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();

      array.unlocked_mut_scope(|slice| slice[1] = 100);

      assert_contents(&array, &[1, 100, 3]);
   }

   #[cfg(feature = "serde")]
   #[test]
   fn test_serde() {
      let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
      let as_text = serde_json::to_string(&array).expect("Serialization failed");
      let as_bytes = serde_json::to_vec(&array).expect("Serialization failed");

      let from_text: SecureArray<u8, 3> =
         serde_json::from_str(&as_text).expect("Deserialization failed");
      let from_bytes: SecureArray<u8, 3> =
         serde_json::from_slice(&as_bytes).expect("Deserialization failed");

      assert_contents(&from_text, &[1, 2, 3]);
      assert_contents(&from_bytes, &[1, 2, 3]);
   }
}