1#[cfg(not(feature = "std"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::{Error, SecureVec};
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "std")]
9use super::page_aligned_size;
10#[cfg(feature = "std")]
11use memsec::Prot;
12
13pub struct SecureArray<T, const LENGTH: usize>
58where
59 T: Zeroize,
60{
61 ptr: NonNull<T>,
62 _marker: PhantomData<T>,
63}
64
65unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
66unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
67
68impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
69where
70 T: Zeroize,
71{
72 pub fn empty() -> Result<Self, Error> {
76 let size = LENGTH * mem::size_of::<T>();
77 if size == 0 {
78 return Err(Error::LengthCannotBeZero);
80 }
81
82 #[cfg(feature = "std")]
83 let new_ptr = {
84 let aligned_size = page_aligned_size(size);
85 let allocated_ptr = unsafe { memsec::malloc_sized(aligned_size) };
86 allocated_ptr.ok_or(Error::AllocationFailed)?.as_ptr() as *mut T
87 };
88
89 #[cfg(not(feature = "std"))]
90 let new_ptr = {
91 let layout = Layout::from_size_align(size, mem::align_of::<T>())
92 .map_err(|_| Error::AllocationFailed)?;
93 let ptr = unsafe { alloc::alloc(layout) as *mut T };
94 if ptr.is_null() {
95 return Err(Error::AllocationFailed);
96 }
97 ptr
98 };
99
100 let non_null = NonNull::new(new_ptr).ok_or(Error::NullAllocation)?;
101
102 let secure_array = SecureArray {
103 ptr: non_null,
104 _marker: PhantomData,
105 };
106
107 let (encrypted, locked) = secure_array.lock_memory();
108
109 #[cfg(feature = "std")]
110 if !locked {
111 return Err(Error::LockFailed);
112 }
113
114 #[cfg(feature = "std")]
115 if !encrypted {
116 return Err(Error::CryptProtectMemoryFailed);
117 }
118
119 Ok(secure_array)
120 }
121
122 pub fn from_slice_mut(content: &mut [T; LENGTH]) -> Result<Self, Error> {
126 let secure_array = Self::empty()?;
127
128 secure_array.unlock_memory();
129
130 unsafe {
131 core::ptr::copy_nonoverlapping(
133 content.as_ptr(),
134 secure_array.ptr.as_ptr(),
135 LENGTH,
136 );
137 }
138
139 content.zeroize();
140
141 let (encrypted, locked) = secure_array.lock_memory();
142
143 #[cfg(feature = "std")]
144 if !locked {
145 return Err(Error::LockFailed);
146 }
147
148 #[cfg(feature = "std")]
149 if !encrypted {
150 return Err(Error::CryptProtectMemoryFailed);
151 }
152
153 Ok(secure_array)
154 }
155
156 pub fn from_slice(content: &[T; LENGTH]) -> Result<Self, Error> {
160 let secure_array = Self::empty()?;
161
162 secure_array.unlock_memory();
163
164 unsafe {
165 core::ptr::copy_nonoverlapping(
167 content.as_ptr(),
168 secure_array.ptr.as_ptr(),
169 LENGTH,
170 );
171 }
172
173 let (encrypted, locked) = secure_array.lock_memory();
174
175 #[cfg(feature = "std")]
176 if !locked {
177 return Err(Error::LockFailed);
178 }
179
180 #[cfg(feature = "std")]
181 if !encrypted {
182 return Err(Error::CryptProtectMemoryFailed);
183 }
184
185 Ok(secure_array)
186 }
187
188 pub fn len(&self) -> usize {
189 LENGTH
190 }
191
192 pub fn is_empty(&self) -> bool {
193 self.len() == 0
194 }
195
196 pub fn as_ptr(&self) -> *const T {
197 self.ptr.as_ptr()
198 }
199
200 pub fn as_mut_ptr(&mut self) -> *mut u8 {
201 self.ptr.as_ptr() as *mut u8
202 }
203
204 #[allow(dead_code)]
205 fn aligned_size(&self) -> usize {
206 let size = self.len() * mem::size_of::<T>();
207 #[cfg(feature = "std")]
208 {
209 page_aligned_size(size)
210 }
211 #[cfg(not(feature = "std"))]
212 {
213 size }
215 }
216
217 #[cfg(all(feature = "std", windows))]
218 fn encypt_memory(&self) -> bool {
219 let ptr = self.as_ptr() as *mut u8;
220 super::crypt_protect_memory(ptr, self.aligned_size())
221 }
222
223 #[cfg(all(feature = "std", windows))]
224 fn decrypt_memory(&self) -> bool {
225 let ptr = self.as_ptr() as *mut u8;
226 super::crypt_unprotect_memory(ptr, self.aligned_size())
227 }
228
229 pub(crate) fn lock_memory(&self) -> (bool, bool) {
230 #[cfg(feature = "std")]
231 {
232 #[cfg(windows)]
233 {
234 let encrypt_ok = self.encypt_memory();
235 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
236 (encrypt_ok, mprotect_ok)
237 }
238 #[cfg(unix)]
239 {
240 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
241 (true, mprotect_ok)
242 }
243 }
244 #[cfg(not(feature = "std"))]
245 {
246 (true, true) }
248 }
249
250 pub(crate) fn unlock_memory(&self) -> (bool, bool) {
251 #[cfg(feature = "std")]
252 {
253 #[cfg(windows)]
254 {
255 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
256 if !mprotect_ok {
257 return (false, false);
258 }
259 let decrypt_ok = self.decrypt_memory();
260 (decrypt_ok, mprotect_ok)
261 }
262 #[cfg(unix)]
263 {
264 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
265 (true, mprotect_ok)
266 }
267 }
268
269 #[cfg(not(feature = "std"))]
270 {
271 (true, true) }
273 }
274
275 pub fn unlock<F, R>(&self, f: F) -> R
277 where
278 F: FnOnce(&[T]) -> R,
279 {
280 self.unlock_memory();
281 let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
282 let result = f(slice);
283 self.lock_memory();
284 result
285 }
286
287 pub fn unlock_mut<F, R>(&mut self, f: F) -> R
289 where
290 F: FnOnce(&mut [T]) -> R,
291 {
292 self.unlock_memory();
293 let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
294 let result = f(slice);
295 self.lock_memory();
296 result
297 }
298
299 pub fn erase(&mut self) {
301 self.unlock_mut(|slice| {
302 for element in slice.iter_mut() {
303 element.zeroize();
304 }
305 });
306 }
307}
308
309impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
310 type Output = T;
311 fn index(&self, index: usize) -> &Self::Output {
312 assert!(index < self.len(), "Index out of bounds");
313 unsafe {
314 let ptr = self.ptr.as_ptr().add(index);
315 &*ptr
316 }
317 }
318}
319
320impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
321 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
322 assert!(index < self.len(), "Index out of bounds");
323 unsafe {
324 let ptr = self.ptr.as_ptr().add(index);
325 &mut *ptr
326 }
327 }
328}
329
330impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
331 fn drop(&mut self) {
332 self.erase();
333 self.unlock_memory();
334
335 let size = LENGTH * mem::size_of::<T>();
336 if size == 0 {
337 return;
338 }
339
340 unsafe {
341 #[cfg(feature = "std")]
342 {
343 memsec::free(self.ptr);
344 }
345 #[cfg(not(feature = "std"))]
346 {
347 let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
349 dealloc(self.ptr.as_ptr() as *mut u8, layout);
350 }
351 }
352 }
353}
354
355impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
356 fn clone(&self) -> Self {
357 let mut new_array = Self::empty().unwrap();
358 self.unlock(|src_slice| {
359 new_array.unlock_mut(|dest_slice| {
360 dest_slice.clone_from_slice(src_slice);
361 });
362 });
363 new_array
364 }
365}
366
367impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
368 type Error = Error;
369
370 fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
376 if vec.len() != LENGTH {
377 return Err(Error::LengthMismatch);
378 }
379
380 let mut new_array = Self::empty()?;
381
382 vec.unlock_slice(|vec_slice| {
383 new_array.unlock_mut(|array_slice| {
384 array_slice.copy_from_slice(vec_slice);
385 });
386 });
387
388 Ok(new_array)
389 }
390}
391
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
    /// Emits the bytes as a sequence of `u8`, reading them inside an `unlock`
    /// scope so the region is locked again once serialization finishes.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        self.unlock(move |bytes| serializer.collect_seq(bytes))
    }
}
401
#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
    /// Deserializes exactly `LENGTH` bytes into a new [`SecureArray`].
    ///
    /// Two input shapes are accepted:
    /// * a sequence of `u8` values, which is what this type's `Serialize` emits;
    /// * a native byte string, which formats may return for the
    ///   `deserialize_bytes` hint used here.
    ///
    /// Any other length is rejected with `invalid_length`.
    fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor building a `SecureArray<u8, L>` from a sequence or byte string.
        struct SecureArrayVisitor<const L: usize>;

        impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
            type Value = SecureArray<u8, L>;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "a byte array of length {}", L)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut data: SecureVec<u8> =
                    SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;
                while let Some(byte) = seq.next_element::<u8>()? {
                    if data.len() == L {
                        // Too long. Count the remaining elements without storing
                        // them, so over-long input cannot grow the buffer.
                        let mut total = L + 1;
                        while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
                            total += 1;
                        }
                        return Err(serde::de::Error::invalid_length(total, &self));
                    }
                    data.push(byte);
                }

                if data.len() != L {
                    return Err(serde::de::Error::invalid_length(data.len(), &self));
                }

                SecureArray::try_from(data).map_err(serde::de::Error::custom)
            }

            // Formats that answer `deserialize_bytes` with a byte string end up
            // here (`visit_borrowed_bytes` / `visit_byte_buf` default to this).
            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let fixed = <&[u8; L]>::try_from(v)
                    .map_err(|_| E::invalid_length(v.len(), &self))?;
                SecureArray::from_slice(fixed).map_err(E::custom)
            }
        }

        deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
    }
}
442
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    #[test]
    fn test_creation() {
        let exposed_mut = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_mut).unwrap();
        assert_eq!(array.len(), 3);

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        // `from_slice_mut` must wipe the caller's buffer.
        assert_eq!(exposed_mut, &[0u8; 3]);

        let exposed = &[1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice(exposed).unwrap();
        assert_eq!(array.len(), 3);

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        // `from_slice` must leave the caller's buffer untouched.
        assert_eq!(exposed, &[1, 2, 3]);
    }

    #[test]
    fn test_from_secure_vec() {
        let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
        let array: SecureArray<u8, 3> = vec.try_into().unwrap();
        assert_eq!(array.len(), 3);
        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    #[test]
    fn test_erase() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        array.erase();
        array.unlock(|slice| {
            assert_eq!(slice, &[0u8; 3]);
        });
    }

    #[test]
    fn test_size_cannot_be_zero() {
        let secure: SecureArray<u8, 3> = SecureArray::from_slice(&[1, 2, 3]).unwrap();
        assert!(secure.aligned_size() > 0);

        let secure: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        assert!(secure.aligned_size() > 0);
    }

    #[test]
    fn test_length_cannot_be_zero() {
        // A zero-length SecureVec passes the length comparison, so the failure
        // must come from `SecureArray::empty` rejecting a zero-byte allocation.
        let secure_vec = SecureVec::new().unwrap();
        let result: Result<SecureArray<u8, 0>, Error> = SecureArray::try_from(secure_vec);
        assert!(matches!(result, Err(Error::LengthCannotBeZero)));
    }

    #[test]
    fn lock_unlock() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let secure: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        assert!(secure.aligned_size() > 0);

        let (decrypted, unlocked) = secure.unlock_memory();
        assert!(decrypted);
        assert!(unlocked);

        let (encrypted, locked) = secure.lock_memory();
        assert!(encrypted);
        assert!(locked);
    }

    #[test]
    fn test_clone() {
        let mut array1: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        array1.unlock_mut(|slice| {
            slice[0] = 1;
            slice[1] = 2;
            slice[2] = 3;
        });

        let array2 = array1.clone();

        array2.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        array1.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    #[test]
    fn test_thread_safety() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let arc_array = Arc::new(Mutex::new(array));
        let mut handles = Vec::new();

        for _ in 0..5u8 {
            let array_clone = Arc::clone(&arc_array);
            let handle = std::thread::spawn(move || {
                let mut guard = array_clone.lock().unwrap();
                guard.unlock_mut(|slice| {
                    slice[0] += 1;
                });
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let final_array = arc_array.lock().unwrap();
        final_array.unlock(|slice| {
            assert_eq!(slice[0], 6);
            assert_eq!(slice[1], 2);
            assert_eq!(slice[2], 3);
        });
    }

    /// Re-runs this test in a child process that indexes a locked array; the
    /// child is expected to die from a memory-access fault.
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_ARRAY_LOCKED";

        if std::env::args().any(|a| a == arg) {
            let exposed: &mut [u8; 3] = &mut [1, 2, 3];
            let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
            let _value = core::hint::black_box(array[0]);

            std::process::exit(1);
        }

        let child = Command::new(std::env::current_exe().unwrap())
            .arg("array::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let output = child.wait_with_output().expect("Failed to wait on child");
        let status = output.status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                signal == libc::SIGSEGV || signal == libc::SIGBUS,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }

    #[test]
    fn test_unlock_mut() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();

        array.unlock_mut(|slice| {
            slice[1] = 100;
        });

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 100, 3]);
        });
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let json_string = serde_json::to_string(&array).expect("Serialization failed");
        let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");

        let deserialized_string: SecureArray<u8, 3> =
            serde_json::from_str(&json_string).expect("Deserialization failed");

        let deserialized_bytes: SecureArray<u8, 3> =
            serde_json::from_slice(&json_bytes).expect("Deserialization failed");

        deserialized_string.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        deserialized_bytes.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }
}