1#[cfg(not(feature = "use_os"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::{Error, SecureVec};
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "use_os")]
9use super::{alloc_aligned, page_aligned_size};
10#[cfg(feature = "use_os")]
11use memsec::Prot;
12
13pub struct SecureArray<T, const LENGTH: usize>
58where
59 T: Zeroize,
60{
61 ptr: NonNull<T>,
62 _marker: PhantomData<T>,
63}
64
65unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
66unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
67
68impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
69where
70 T: Zeroize,
71{
72 pub fn empty() -> Result<Self, Error> {
76 let size = LENGTH * mem::size_of::<T>();
77 if size == 0 {
78 return Err(Error::LengthCannotBeZero);
80 }
81
82 let ptr = alloc_aligned::<T>(size)?;
83
84 let secure_array = SecureArray {
85 ptr,
86 _marker: PhantomData,
87 };
88
89 let (encrypted, locked) = secure_array.lock_memory();
90
91 #[cfg(feature = "use_os")]
92 if !locked {
93 return Err(Error::LockFailed);
94 }
95
96 #[cfg(feature = "use_os")]
97 if !encrypted {
98 return Err(Error::CryptProtectMemoryFailed);
99 }
100
101 Ok(secure_array)
102 }
103
104 pub fn from_slice_mut(content: &mut [T; LENGTH]) -> Result<Self, Error> {
108 let secure_array = Self::empty()?;
109
110 secure_array.unlock_memory();
111
112 unsafe {
113 core::ptr::copy_nonoverlapping(
115 content.as_ptr(),
116 secure_array.ptr.as_ptr(),
117 LENGTH,
118 );
119 }
120
121 content.zeroize();
122
123 let (encrypted, locked) = secure_array.lock_memory();
124
125 #[cfg(feature = "use_os")]
126 if !locked {
127 return Err(Error::LockFailed);
128 }
129
130 #[cfg(feature = "use_os")]
131 if !encrypted {
132 return Err(Error::CryptProtectMemoryFailed);
133 }
134
135 Ok(secure_array)
136 }
137
138 pub fn from_slice(content: &[T; LENGTH]) -> Result<Self, Error> {
142 let secure_array = Self::empty()?;
143
144 secure_array.unlock_memory();
145
146 unsafe {
147 core::ptr::copy_nonoverlapping(
149 content.as_ptr(),
150 secure_array.ptr.as_ptr(),
151 LENGTH,
152 );
153 }
154
155 let (encrypted, locked) = secure_array.lock_memory();
156
157 #[cfg(feature = "use_os")]
158 if !locked {
159 return Err(Error::LockFailed);
160 }
161
162 #[cfg(feature = "use_os")]
163 if !encrypted {
164 return Err(Error::CryptProtectMemoryFailed);
165 }
166
167 Ok(secure_array)
168 }
169
170 pub fn len(&self) -> usize {
171 LENGTH
172 }
173
174 pub fn is_empty(&self) -> bool {
175 self.len() == 0
176 }
177
178 pub fn as_ptr(&self) -> *const T {
179 self.ptr.as_ptr()
180 }
181
182 pub fn as_mut_ptr(&mut self) -> *mut u8 {
183 self.ptr.as_ptr() as *mut u8
184 }
185
186 #[allow(dead_code)]
187 fn aligned_size(&self) -> usize {
188 let size = self.len() * mem::size_of::<T>();
189 #[cfg(feature = "use_os")]
190 {
191 unsafe { page_aligned_size(size) }
192 }
193 #[cfg(not(feature = "use_os"))]
194 {
195 size }
197 }
198
199 #[cfg(all(feature = "use_os", windows))]
200 fn encypt_memory(&self) -> bool {
201 let ptr = self.as_ptr() as *mut u8;
202 super::crypt_protect_memory(ptr, self.aligned_size())
203 }
204
205 #[cfg(all(feature = "use_os", windows))]
206 fn decrypt_memory(&self) -> bool {
207 let ptr = self.as_ptr() as *mut u8;
208 super::crypt_unprotect_memory(ptr, self.aligned_size())
209 }
210
211 pub(crate) fn lock_memory(&self) -> (bool, bool) {
212 #[cfg(feature = "use_os")]
213 {
214 #[cfg(windows)]
215 {
216 let encrypt_ok = self.encypt_memory();
217 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
218 (encrypt_ok, mprotect_ok)
219 }
220 #[cfg(unix)]
221 {
222 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
223 (true, mprotect_ok)
224 }
225 }
226 #[cfg(not(feature = "use_os"))]
227 {
228 (true, true) }
230 }
231
232 pub(crate) fn unlock_memory(&self) -> (bool, bool) {
233 #[cfg(feature = "use_os")]
234 {
235 #[cfg(windows)]
236 {
237 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
238 if !mprotect_ok {
239 return (false, false);
240 }
241 let decrypt_ok = self.decrypt_memory();
242 (decrypt_ok, mprotect_ok)
243 }
244 #[cfg(unix)]
245 {
246 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
247 (true, mprotect_ok)
248 }
249 }
250
251 #[cfg(not(feature = "use_os"))]
252 {
253 (true, true) }
255 }
256
257 pub fn unlock<F, R>(&self, f: F) -> R
259 where
260 F: FnOnce(&[T]) -> R,
261 {
262 self.unlock_memory();
263 let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
264 let result = f(slice);
265 self.lock_memory();
266 result
267 }
268
269 pub fn unlock_mut<F, R>(&mut self, f: F) -> R
271 where
272 F: FnOnce(&mut [T]) -> R,
273 {
274 self.unlock_memory();
275 let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
276 let result = f(slice);
277 self.lock_memory();
278 result
279 }
280
281 pub fn erase(&mut self) {
283 self.unlock_mut(|slice| {
284 for element in slice.iter_mut() {
285 element.zeroize();
286 }
287 });
288 }
289}
290
291impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
292 type Output = T;
293 fn index(&self, index: usize) -> &Self::Output {
294 assert!(index < self.len(), "Index out of bounds");
295 unsafe {
296 let ptr = self.ptr.as_ptr().add(index);
297 &*ptr
298 }
299 }
300}
301
302impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
303 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
304 assert!(index < self.len(), "Index out of bounds");
305 unsafe {
306 let ptr = self.ptr.as_ptr().add(index);
307 &mut *ptr
308 }
309 }
310}
311
312impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
313 fn drop(&mut self) {
314 self.erase();
315 self.unlock_memory();
316
317 let size = LENGTH * mem::size_of::<T>();
318 if size == 0 {
319 return;
320 }
321
322 unsafe {
323 #[cfg(feature = "use_os")]
324 {
325 memsec::free(self.ptr);
326 }
327 #[cfg(not(feature = "use_os"))]
328 {
329 let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
330 dealloc(self.ptr.as_ptr() as *mut u8, layout);
331 }
332 }
333 }
334}
335
336impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
337 fn clone(&self) -> Self {
338 let mut new_array = Self::empty().unwrap();
339 self.unlock(|src_slice| {
340 new_array.unlock_mut(|dest_slice| {
341 dest_slice.clone_from_slice(src_slice);
342 });
343 });
344 new_array
345 }
346}
347
348impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
349 type Error = Error;
350
351 fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
357 if vec.len() != LENGTH {
358 return Err(Error::LengthMismatch);
359 }
360
361 let mut new_array = Self::empty()?;
362
363 vec.unlock_slice(|vec_slice| {
364 new_array.unlock_mut(|array_slice| {
365 array_slice.copy_from_slice(vec_slice);
366 });
367 });
368
369 Ok(new_array)
370 }
371}
372
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
    /// Writes the bytes as a sequence. The buffer is unlocked only for the
    /// duration of the `collect_seq` call.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.unlock(move |bytes| {
            let elements = bytes.iter();
            serializer.collect_seq(elements)
        })
    }
}
382
#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
    /// Deserialises exactly `LENGTH` bytes.
    ///
    /// `deserialize_bytes` is only a hint to the format. Self-describing
    /// formats such as JSON answer a JSON array with a sequence, while many
    /// binary formats hand over a byte buffer. The visitor therefore accepts
    /// both.
    fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct SecureArrayVisitor<const L: usize>;

        impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
            type Value = SecureArray<u8, L>;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "a byte array of length {}", L)
            }

            /// Accepts a byte buffer of exactly `L` bytes.
            fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if bytes.len() != L {
                    return Err(E::invalid_length(bytes.len(), &self));
                }
                let mut array = SecureArray::<u8, L>::empty().map_err(E::custom)?;
                array.unlock_mut(|slice| slice.copy_from_slice(bytes));
                Ok(array)
            }

            /// Accepts a sequence of exactly `L` bytes.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut data: SecureVec<u8> =
                    SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;
                while let Some(byte) = seq.next_element::<u8>()? {
                    // Reject oversized input as soon as it exceeds `L`, instead
                    // of first buffering an input-controlled amount of data.
                    if data.len() == L {
                        return Err(serde::de::Error::invalid_length(L + 1, &self));
                    }
                    data.push(byte);
                }

                if data.len() != L {
                    return Err(serde::de::Error::invalid_length(data.len(), &self));
                }

                SecureArray::try_from(data).map_err(serde::de::Error::custom)
            }
        }

        deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
    }
}
423
// These tests exercise the OS-backed protection (`mprotect`, plus encryption on
// Windows), so they are only compiled when `use_os` is enabled.
#[cfg(all(test, feature = "use_os"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    /// `from_slice_mut` copies the data and zeroizes its source.
    /// `from_slice` copies the data and leaves its source untouched.
    #[test]
    fn test_creation() {
        let exposed_mut = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_mut).unwrap();
        assert_eq!(array.len(), 3);

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        assert_eq!(exposed_mut, &[0u8; 3]);

        let exposed = &[1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice(exposed).unwrap();
        assert_eq!(array.len(), 3);

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        assert_eq!(exposed, &[1, 2, 3]);
    }

    /// A `SecureVec<u8>` of matching length converts into a `SecureArray`.
    #[test]
    fn test_from_secure_vec() {
        let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
        let array: SecureArray<u8, 3> = vec.try_into().unwrap();
        assert_eq!(array.len(), 3);
        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    /// `erase` zeroizes every element.
    #[test]
    fn test_erase() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        array.erase();
        array.unlock(|slice| {
            assert_eq!(slice, &[0u8; 3]);
        });
    }

    /// `aligned_size` is non-zero for both filled and freshly created arrays.
    #[test]
    fn test_size_cannot_be_zero() {
        let secure: SecureArray<u8, 3> = SecureArray::from_slice(&[1, 2, 3]).unwrap();
        let size = secure.aligned_size();
        assert_eq!(size > 0, true);

        let secure: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        let size = secure.aligned_size();
        assert_eq!(size > 0, true);
    }

    /// Converting an empty `SecureVec` into a zero-length array returns an
    /// error. The `unwrap` turns it into the expected panic.
    #[test]
    #[should_panic]
    fn test_length_cannot_be_zero() {
        let secure_vec = SecureVec::new().unwrap();
        let _secure_array: SecureArray<u8, 0> = SecureArray::try_from(secure_vec).unwrap();
    }

    /// Direct `unlock_memory` / `lock_memory` calls report success on a newly
    /// created array.
    #[test]
    fn lock_unlock() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let secure: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let size = secure.aligned_size();
        assert_eq!(size > 0, true);

        let (decrypted, unlocked) = secure.unlock_memory();
        assert!(decrypted);
        assert!(unlocked);

        let (encrypted, locked) = secure.lock_memory();
        assert!(encrypted);
        assert!(locked);
    }

    /// A clone holds the same data, and the original stays readable.
    #[test]
    fn test_clone() {
        let mut array1: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        array1.unlock_mut(|slice| {
            slice[0] = 1;
            slice[1] = 2;
            slice[2] = 3;
        });

        let array2 = array1.clone();

        array2.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        array1.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    /// Five threads each increment element 0 once under a `Mutex`, so it ends
    /// at 1 + 5 = 6 while the other elements are untouched.
    #[test]
    fn test_thread_safety() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let arc_array = Arc::new(Mutex::new(array));
        let mut handles = Vec::new();

        for _ in 0..5u8 {
            let array_clone = Arc::clone(&arc_array);
            let handle = std::thread::spawn(move || {
                let mut guard = array_clone.lock().unwrap();
                guard.unlock_mut(|slice| {
                    slice[0] += 1;
                });
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let final_array = arc_array.lock().unwrap();
        final_array.unlock(|slice| {
            assert_eq!(slice[0], 6);
            assert_eq!(slice[1], 2);
            assert_eq!(slice[2], 3);
        });
    }

    /// Indexing a locked array must crash the process.
    ///
    /// The parent re-runs the current test binary, filtered to this test, with
    /// a marker argument. The child sees the marker, indexes the locked array
    /// and should be killed (SIGSEGV/SIGBUS on Unix, STATUS_ACCESS_VIOLATION on
    /// Windows). If the child reaches `exit(1)`, it did not crash, and the
    /// checks below fail. NOTE: the filter string must match this test's
    /// module path.
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_ARRAY_LOCKED";

        // Child process branch.
        if std::env::args().any(|a| a == arg) {
            let exposed: &mut [u8; 3] = &mut [1, 2, 3];
            let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
            // `black_box` keeps the read from being optimised away.
            let _value = core::hint::black_box(array[0]);

            std::process::exit(1);
        }

        // Parent process branch.
        let child = Command::new(std::env::current_exe().unwrap())
            .arg("array::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let output = child.wait_with_output().expect("Failed to wait on child");
        let status = output.status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                signal == libc::SIGSEGV || signal == libc::SIGBUS,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }

    /// Writes made through `unlock_mut` are visible to a later `unlock`.
    #[test]
    fn test_unlock_mut() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();

        array.unlock_mut(|slice| {
            slice[1] = 100;
        });

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 100, 3]);
        });
    }

    /// JSON round-trip through both the string and the byte-slice APIs.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let json_string = serde_json::to_string(&array).expect("Serialization failed");
        let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");

        let deserialized_string: SecureArray<u8, 3> =
            serde_json::from_str(&json_string).expect("Deserialization failed");

        let deserialized_bytes: SecureArray<u8, 3> =
            serde_json::from_slice(&json_bytes).expect("Deserialization failed");

        deserialized_string.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        deserialized_bytes.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }
}