1#[cfg(not(feature = "std"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::{Error, SecureVec};
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "std")]
9use super::page_size;
10#[cfg(feature = "std")]
11use memsec::Prot;
12
13pub struct SecureArray<T, const LENGTH: usize>
41where
42 T: Zeroize,
43{
44 ptr: NonNull<T>,
45 _marker: PhantomData<T>,
46}
47
// SAFETY: a `SecureArray` is the only owner of its buffer (the buffer is released
// solely in `Drop`, and `Clone` allocates a fresh one), so sending it to another
// thread is equivalent to sending the `T` values it holds.
unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
// SAFETY: through `&SecureArray` callers only obtain `&T` / `&[T]` (`Index`,
// `unlocked_scope`).
// NOTE(review): `unlocked_scope(&self)` calls `unlock_memory`/`lock_memory`, which
// change the page protection of the shared buffer. Two threads using the same
// `&SecureArray` concurrently can race: one thread's re-lock can revoke access
// while the other is still reading, and on Windows each unlock decrypts again, so
// overlapping unlocks decrypt twice. Confirm whether shared access must be
// externally synchronised or whether this impl should be tightened.
unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
50
51impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
52where
53 T: Zeroize,
54{
55 pub fn empty() -> Result<Self, Error> {
58 let size = LENGTH * mem::size_of::<T>();
59 if size == 0 {
60 return Err(Error::AllocationFailed);
62 }
63
64 #[cfg(feature = "std")]
65 let new_ptr = {
66 let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
67 let allocated_ptr = unsafe { memsec::malloc_sized(aligned_size) };
68 allocated_ptr.ok_or(Error::AllocationFailed)?.as_ptr() as *mut T
69 };
70
71 #[cfg(not(feature = "std"))]
72 let new_ptr = {
73 let layout = Layout::from_size_align(size, mem::align_of::<T>())
74 .map_err(|_| Error::AllocationFailed)?;
75 let ptr = unsafe { alloc::alloc(layout) as *mut T };
76 if ptr.is_null() {
77 return Err(Error::AllocationFailed);
78 }
79 ptr
80 };
81
82 let non_null = NonNull::new(new_ptr).ok_or(Error::NullAllocation)?;
83
84 let secure_array = SecureArray {
85 ptr: non_null,
86 _marker: PhantomData,
87 };
88
89 let (encrypted, locked) = secure_array.lock_memory();
90
91 #[cfg(feature = "std")]
92 if !locked {
93 return Err(Error::LockFailed);
94 }
95
96 #[cfg(feature = "std")]
97 if !encrypted {
98 return Err(Error::CryptProtectMemoryFailed);
99 }
100
101 Ok(secure_array)
102 }
103
104 pub fn new(mut content: [T; LENGTH]) -> Result<Self, Error> {
106 let secure_array = Self::empty()?;
107
108 secure_array.unlock_memory();
109
110 unsafe {
111 core::ptr::copy_nonoverlapping(
113 content.as_ptr(),
114 secure_array.ptr.as_ptr(),
115 LENGTH,
116 );
117 }
118
119 content.zeroize();
120
121 let (encrypted, locked) = secure_array.lock_memory();
122
123 #[cfg(feature = "std")]
124 if !locked {
125 return Err(Error::LockFailed);
126 }
127
128 #[cfg(feature = "std")]
129 if !encrypted {
130 return Err(Error::CryptProtectMemoryFailed);
131 }
132
133 Ok(secure_array)
134 }
135
136 pub fn len(&self) -> usize {
137 LENGTH
138 }
139
140 pub fn as_ptr(&self) -> *const T {
141 self.ptr.as_ptr()
142 }
143
144 pub fn as_mut_ptr(&mut self) -> *mut u8 {
145 self.ptr.as_ptr() as *mut u8
146 }
147
148 #[allow(dead_code)]
149 fn allocated_byte_size(&self) -> usize {
150 let size = self.len() * mem::size_of::<T>();
151 #[cfg(feature = "std")]
152 {
153 (size + page_size() - 1) & !(page_size() - 1)
154 }
155 #[cfg(not(feature = "std"))]
156 {
157 size }
159 }
160
161 #[cfg(all(feature = "std", windows))]
162 fn encypt_memory(&self) -> bool {
163 let ptr = self.as_ptr() as *mut u8;
164 super::crypt_protect_memory(ptr, self.allocated_byte_size())
165 }
166
167 #[cfg(all(feature = "std", windows))]
168 fn decrypt_memory(&self) -> bool {
169 let ptr = self.as_ptr() as *mut u8;
170 super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
171 }
172
173 pub(crate) fn lock_memory(&self) -> (bool, bool) {
174 #[cfg(feature = "std")]
175 {
176 #[cfg(windows)]
177 {
178 let encrypt_ok = self.encypt_memory();
179 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
180 (encrypt_ok, mprotect_ok)
181 }
182 #[cfg(unix)]
183 {
184 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
185 (true, mprotect_ok)
186 }
187 }
188 #[cfg(not(feature = "std"))]
189 {
190 (true, true) }
192 }
193
194 pub(crate) fn unlock_memory(&self) -> (bool, bool) {
195 #[cfg(feature = "std")]
196 {
197 #[cfg(windows)]
198 {
199 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
200 if !mprotect_ok {
201 return (false, false);
202 }
203 let decrypt_ok = self.decrypt_memory();
204 (decrypt_ok, mprotect_ok)
205 }
206 #[cfg(unix)]
207 {
208 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
209 (true, mprotect_ok)
210 }
211 }
212
213 #[cfg(not(feature = "std"))]
214 {
215 (true, true) }
217 }
218
219 pub fn unlocked_scope<F, R>(&self, f: F) -> R
221 where
222 F: FnOnce(&[T]) -> R,
223 {
224 self.unlock_memory();
225 let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
226 let result = f(slice);
227 self.lock_memory();
228 result
229 }
230
231 pub fn unlocked_mut_scope<F, R>(&mut self, f: F) -> R
233 where
234 F: FnOnce(&mut [T]) -> R,
235 {
236 self.unlock_memory();
237 let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
238 let result = f(slice);
239 self.lock_memory();
240 result
241 }
242
243 pub fn erase(&mut self) {
245 self.unlocked_mut_scope(|slice| {
246 for element in slice.iter_mut() {
247 element.zeroize();
248 }
249 });
250 }
251}
252
253impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
254 type Output = T;
255 fn index(&self, index: usize) -> &Self::Output {
256 assert!(index < self.len(), "Index out of bounds");
257 unsafe {
258 let ptr = self.ptr.as_ptr().add(index);
259 let reference = &*ptr;
260 reference
261 }
262 }
263}
264
265impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
266 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
267 assert!(index < self.len(), "Index out of bounds");
268 unsafe {
269 let ptr = self.ptr.as_ptr().add(index);
270 let reference = &mut *ptr;
271 reference
272 }
273 }
274}
275
276impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
277 fn drop(&mut self) {
278 self.erase();
279 self.unlock_memory();
280
281 let size = LENGTH * mem::size_of::<T>();
282 if size == 0 {
283 return;
284 }
285
286 unsafe {
287 #[cfg(feature = "std")]
288 {
289 memsec::free(self.ptr);
290 }
291 #[cfg(not(feature = "std"))]
292 {
293 let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
295 dealloc(self.ptr.as_ptr() as *mut u8, layout);
296 }
297 }
298 }
299}
300
301impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
302 fn clone(&self) -> Self {
303 let mut new_array = Self::empty().unwrap();
304 self.unlocked_scope(|src_slice| {
305 new_array.unlocked_mut_scope(|dest_slice| {
306 dest_slice.clone_from_slice(src_slice);
307 });
308 });
309 new_array
310 }
311}
312
313impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
314 type Error = Error;
315
316 fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
321 if vec.len() != LENGTH {
322 return Err(Error::LengthMismatch);
323 }
324
325 let mut new_array = Self::empty()?;
326
327 vec.slice_scope(|vec_slice| {
328 new_array.unlocked_mut_scope(|array_slice| {
329 array_slice.copy_from_slice(vec_slice);
330 });
331 });
332
333 Ok(new_array)
334 }
335}
336
337impl<T, const LENGTH: usize> TryFrom<[T; LENGTH]> for SecureArray<T, LENGTH>
338where
339 T: Zeroize,
340{
341 type Error = Error;
342 fn try_from(s: [T; LENGTH]) -> Result<Self, Error> {
343 Self::new(s)
344 }
345}
346
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
    /// Emits the bytes as a sequence of `u8`, unlocking the buffer only while the
    /// serializer runs.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.unlocked_scope(move |bytes| serializer.collect_seq(bytes))
    }
}
356
#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
    /// Deserializes exactly `LENGTH` bytes.
    ///
    /// The deserializer is asked for bytes (`deserialize_bytes`). Both replies are
    /// accepted: a byte buffer (`visit_bytes`, e.g. binary formats) or a sequence
    /// of `u8` (`visit_seq`, e.g. JSON arrays).
    ///
    /// # Errors
    ///
    /// An `invalid_length` error if the input does not hold exactly `LENGTH` bytes,
    /// or a custom error if secure allocation fails.
    fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct SecureArrayVisitor<const L: usize>;

        impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
            type Value = SecureArray<u8, L>;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "a byte array of length {}", L)
            }

            // Formats that answer `deserialize_bytes` with a byte buffer land here
            // (the default `visit_borrowed_bytes` / `visit_byte_buf` forward to it);
            // without this method they were rejected as an invalid type.
            fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if bytes.len() != L {
                    return Err(E::invalid_length(bytes.len(), &self));
                }

                let mut array = SecureArray::<u8, L>::empty().map_err(E::custom)?;
                array.unlocked_mut_scope(|dest| dest.copy_from_slice(bytes));
                Ok(array)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut data: SecureVec<u8> =
                    SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;

                // Store at most `L` bytes but keep counting, so oversized input is
                // neither buffered nor misreported in the error.
                let mut received = 0usize;
                while let Some(byte) = seq.next_element::<u8>()? {
                    received += 1;
                    if received <= L {
                        data.push(byte);
                    }
                }

                if received != L {
                    return Err(serde::de::Error::invalid_length(received, &self));
                }

                SecureArray::try_from(data).map_err(serde::de::Error::custom)
            }
        }

        deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
    }
}
397
// Tests for `SecureArray`. Compiled only for test builds with the `std` feature,
// since they exercise the `memsec`-backed allocation and page protection.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    // `new` stores the given values and `len` reports the const length.
    #[test]
    fn test_creation() {
        let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
        assert_eq!(array.len(), 3);
        array.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    // Converting a `SecureVec<u8>` of matching length preserves its bytes.
    #[test]
    fn test_from_secure_vec() {
        let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
        let array: SecureArray<u8, 3> = vec.try_into().unwrap();
        assert_eq!(array.len(), 3);
        array.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    // `erase` zeroizes every element in place.
    #[test]
    fn test_erase() {
        let mut array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
        array.erase();
        array.unlocked_scope(|slice| {
            assert_eq!(slice, &[0u8; 3]);
        });
    }

    // A clone holds the same values, and the source is left unchanged.
    #[test]
    fn test_clone() {
        let mut array1: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        array1.unlocked_mut_scope(|slice| {
            slice[0] = 1;
            slice[1] = 2;
            slice[2] = 3;
        });

        let array2 = array1.clone();

        array2.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        array1.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    // Five threads each increment `slice[0]` once through a shared `Mutex`;
    // starting from 1, the final value must be 6.
    #[test]
    fn test_thread_safety() {
        let array = SecureArray::new([1, 2, 3]).unwrap();
        let arc_array = Arc::new(Mutex::new(array));
        let mut handles = Vec::new();

        for _ in 0..5u8 {
            let array_clone = Arc::clone(&arc_array);
            let handle = std::thread::spawn(move || {
                let mut guard = array_clone.lock().unwrap();
                guard.unlocked_mut_scope(|slice| {
                    slice[0] += 1;
                });
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let final_array = arc_array.lock().unwrap();
        final_array.unlocked_scope(|slice| {
            assert_eq!(slice[0], 6);
        });
    }

    // Indexing does not unlock the buffer, so reading while locked must crash the
    // process. The test re-runs this test binary as a child with `arg` as a marker.
    // The child performs the read; the parent inspects how the child terminated.
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_ARRAY_LOCKED";

        // Child branch: taken only when the marker argument is present.
        if std::env::args().any(|a| a == arg) {
            let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
            // The read is passed through `black_box` so it is not optimised away.
            let _value = core::hint::black_box(array[0]);

            // Reached only if the read did not fault; a plain exit code makes the
            // parent's signal / exit-code checks below fail.
            std::process::exit(1);
        }

        // NOTE(review): the filter string assumes this test lives at
        // `array::tests` in the crate; update it if the module moves.
        let child = Command::new(std::env::current_exe().unwrap())
            .arg("array::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let output = child.wait_with_output().expect("Failed to wait on child");
        let status = output.status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        // Unix: the child must have been killed by SIGSEGV or SIGBUS.
        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                signal == libc::SIGSEGV || signal == libc::SIGBUS,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        // Windows: the child must exit with STATUS_ACCESS_VIOLATION (0xC0000005).
        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }

    // Writes made through `unlocked_mut_scope` persist after the buffer is re-locked.
    #[test]
    fn test_mutable_access_in_scope() {
        let mut array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();

        array.unlocked_mut_scope(|slice| {
            slice[1] = 100;
        });

        array.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 100, 3]);
        });
    }

    // Round-trips through serde_json, parsing both from a `String` and from bytes.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
        let json_string = serde_json::to_string(&array).expect("Serialization failed");
        let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");

        let deserialized_string: SecureArray<u8, 3> =
            serde_json::from_str(&json_string).expect("Deserialization failed");

        let deserialized_bytes: SecureArray<u8, 3> =
            serde_json::from_slice(&json_bytes).expect("Deserialization failed");

        deserialized_string.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        deserialized_bytes.unlocked_scope(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }
}