1#[cfg(not(feature = "use_os"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::{Error, SecureVec};
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "use_os")]
9use super::{alloc, free};
10#[cfg(feature = "use_os")]
11use memsec::Prot;
12
13pub struct SecureArray<T, const LENGTH: usize>
57where
58 T: Zeroize,
59{
60 ptr: NonNull<T>,
61 _marker: PhantomData<T>,
62}
63
64unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
65unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
66
67impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
68where
69 T: Zeroize,
70{
71 pub fn empty() -> Result<Self, Error> {
75 let size = LENGTH * mem::size_of::<T>();
76 if size == 0 {
77 return Err(Error::LengthCannotBeZero);
79 }
80
81 let ptr = unsafe { alloc::<T>(size)? };
82
83 let secure_array = SecureArray {
84 ptr,
85 _marker: PhantomData,
86 };
87
88 let locked = secure_array.lock_memory();
89
90 #[cfg(feature = "use_os")]
91 if !locked {
92 return Err(Error::LockFailed);
93 }
94
95 Ok(secure_array)
96 }
97
98 pub fn from_slice_mut(content: &mut [T; LENGTH]) -> Result<Self, Error> {
102 let secure_array = Self::empty()?;
103
104 let unlocked = secure_array.unlock_memory();
105
106 if !unlocked {
107 return Err(Error::UnlockFailed);
108 }
109
110 unsafe {
111 core::ptr::copy_nonoverlapping(
113 content.as_ptr(),
114 secure_array.ptr.as_ptr(),
115 LENGTH,
116 );
117 }
118
119 content.zeroize();
120
121 let locked = secure_array.lock_memory();
122
123 #[cfg(feature = "use_os")]
124 if !locked {
125 return Err(Error::LockFailed);
126 }
127
128 Ok(secure_array)
129 }
130
131 pub fn from_slice(content: &[T; LENGTH]) -> Result<Self, Error> {
135 let secure_array = Self::empty()?;
136
137 let unlocked = secure_array.unlock_memory();
138
139 if !unlocked {
140 return Err(Error::UnlockFailed);
141 }
142
143 unsafe {
144 core::ptr::copy_nonoverlapping(
146 content.as_ptr(),
147 secure_array.ptr.as_ptr(),
148 LENGTH,
149 );
150 }
151
152 let locked = secure_array.lock_memory();
153
154 #[cfg(feature = "use_os")]
155 if !locked {
156 return Err(Error::LockFailed);
157 }
158
159 Ok(secure_array)
160 }
161
162 pub fn len(&self) -> usize {
163 LENGTH
164 }
165
166 pub fn is_empty(&self) -> bool {
167 self.len() == 0
168 }
169
170 pub(crate) fn lock_memory(&self) -> bool {
171 #[cfg(feature = "use_os")]
172 {
173 #[cfg(windows)]
174 {
175 super::mprotect(self.ptr, Prot::NoAccess)
176 }
177 #[cfg(unix)]
178 {
179 super::mprotect(self.ptr, Prot::NoAccess)
180 }
181 }
182 #[cfg(not(feature = "use_os"))]
183 {
184 true }
186 }
187
188 pub(crate) fn unlock_memory(&self) -> bool {
189 #[cfg(feature = "use_os")]
190 {
191 #[cfg(windows)]
192 {
193 super::mprotect(self.ptr, Prot::ReadWrite)
194 }
195 #[cfg(unix)]
196 {
197 super::mprotect(self.ptr, Prot::ReadWrite)
198 }
199 }
200
201 #[cfg(not(feature = "use_os"))]
202 {
203 true }
205 }
206
207 pub fn unlock<F, R>(&self, f: F) -> R
209 where
210 F: FnOnce(&[T]) -> R,
211 {
212 self.unlock_memory();
213 let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
214 let result = f(slice);
215 self.lock_memory();
216 result
217 }
218
219 pub fn unlock_mut<F, R>(&mut self, f: F) -> R
221 where
222 F: FnOnce(&mut [T]) -> R,
223 {
224 self.unlock_memory();
225 let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
226 let result = f(slice);
227 self.lock_memory();
228 result
229 }
230
231 pub fn erase(&mut self) {
233 self.unlock_mut(|slice| {
234 for element in slice.iter_mut() {
235 element.zeroize();
236 }
237 });
238 }
239}
240
241impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
242 type Output = T;
243 fn index(&self, index: usize) -> &Self::Output {
244 assert!(index < self.len(), "Index out of bounds");
245 unsafe {
246 let ptr = self.ptr.as_ptr().add(index);
247 &*ptr
248 }
249 }
250}
251
252impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
253 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
254 assert!(index < self.len(), "Index out of bounds");
255 unsafe {
256 let ptr = self.ptr.as_ptr().add(index);
257 &mut *ptr
258 }
259 }
260}
261
262impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
263 fn drop(&mut self) {
264 self.erase();
265 self.unlock_memory();
266
267 let size = LENGTH * mem::size_of::<T>();
268 if size == 0 {
269 return;
270 }
271
272 #[cfg(feature = "use_os")]
273 free(self.ptr);
274
275 #[cfg(not(feature = "use_os"))]
276 {
277 let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
278 dealloc(self.ptr.as_ptr() as *mut u8, layout);
279 }
280 }
281}
282
283impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
284 fn clone(&self) -> Self {
285 let mut new_array = Self::empty().unwrap();
286 self.unlock(|src_slice| {
287 new_array.unlock_mut(|dest_slice| {
288 dest_slice.clone_from_slice(src_slice);
289 });
290 });
291 new_array
292 }
293}
294
295impl<const LENGTH: usize> TryFrom<SecureVec<u8>> for SecureArray<u8, LENGTH> {
296 type Error = Error;
297
298 fn try_from(vec: SecureVec<u8>) -> Result<Self, Self::Error> {
304 if vec.len() != LENGTH {
305 return Err(Error::LengthMismatch);
306 }
307
308 let mut new_array = Self::empty()?;
309
310 vec.unlock_slice(|vec_slice| {
311 new_array.unlock_mut(|array_slice| {
312 array_slice.copy_from_slice(vec_slice);
313 });
314 });
315
316 Ok(new_array)
317 }
318}
319
#[cfg(feature = "serde")]
impl<const LENGTH: usize> serde::Serialize for SecureArray<u8, LENGTH> {
    /// Emits the bytes through `collect_seq`, i.e. as a sequence of `u8`
    /// values rather than a byte string. The memory is unlocked only while
    /// the sequence is being written.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        self.unlock(move |bytes| serializer.collect_seq(bytes))
    }
}
329
#[cfg(feature = "serde")]
impl<'de, const LENGTH: usize> serde::Deserialize<'de> for SecureArray<u8, LENGTH> {
    /// Deserializes exactly `LENGTH` bytes, supplied either as a byte string or
    /// as a sequence of `u8` values. Any other length is rejected with
    /// `invalid_length`.
    fn deserialize<D>(deserializer: D) -> Result<SecureArray<u8, LENGTH>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct SecureArrayVisitor<const L: usize>;

        impl<'de, const L: usize> serde::de::Visitor<'de> for SecureArrayVisitor<L> {
            type Value = SecureArray<u8, L>;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "a byte array of length {}", L)
            }

            // Formats that honour `deserialize_bytes` hand over a byte slice
            // here. Serde's default `visit_borrowed_bytes` / `visit_byte_buf`
            // forward to this method. Without it, such formats fail to
            // round-trip the array.
            fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if bytes.len() != L {
                    return Err(E::invalid_length(bytes.len(), &self));
                }

                let mut array = SecureArray::<u8, L>::empty().map_err(E::custom)?;
                array.unlock_mut(|slot| slot.copy_from_slice(bytes));
                Ok(array)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut data: SecureVec<u8> =
                    SecureVec::new_with_capacity(L).map_err(serde::de::Error::custom)?;

                // Buffer at most `L` elements, so oversized input cannot grow
                // the buffer without bound.
                while data.len() < L {
                    match seq.next_element()? {
                        Some(byte) => {
                            data.push(byte);
                        }
                        None => {
                            return Err(serde::de::Error::invalid_length(data.len(), &self));
                        }
                    }
                }

                // Count (without storing) any surplus elements, so the error
                // still reports the real input length.
                let mut surplus = 0usize;
                while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
                    surplus += 1;
                }
                if surplus != 0 {
                    return Err(serde::de::Error::invalid_length(
                        L.saturating_add(surplus),
                        &self,
                    ));
                }

                SecureArray::try_from(data).map_err(serde::de::Error::custom)
            }
        }

        deserializer.deserialize_bytes(SecureArrayVisitor::<LENGTH>)
    }
}
370
#[cfg(all(test, feature = "use_os"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    #[test]
    fn test_creation() {
        let exposed_mut = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed_mut).unwrap();
        assert_eq!(array.len(), 3);

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        // `from_slice_mut` must have wiped the source buffer.
        assert_eq!(exposed_mut, &[0u8; 3]);

        let exposed = &[1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice(exposed).unwrap();
        assert_eq!(array.len(), 3);

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        // `from_slice` must leave the source untouched.
        assert_eq!(exposed, &[1, 2, 3]);
    }

    #[test]
    fn test_from_secure_vec() {
        let vec: SecureVec<u8> = SecureVec::from_slice(&[1, 2, 3]).unwrap();
        let array: SecureArray<u8, 3> = vec.try_into().unwrap();
        assert_eq!(array.len(), 3);
        array.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    #[test]
    fn test_erase() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        array.erase();
        array.unlock(|slice| {
            assert_eq!(slice, &[0u8; 3]);
        });
    }

    #[test]
    fn test_length_cannot_be_zero() {
        let secure_vec = SecureVec::new().unwrap();
        // Check the specific error instead of relying on `#[should_panic]`,
        // which would also accept an unrelated panic (e.g. from `SecureVec::new`).
        let result: Result<SecureArray<u8, 0>, Error> = SecureArray::try_from(secure_vec);
        assert!(matches!(result, Err(Error::LengthCannotBeZero)));
    }

    #[test]
    fn lock_unlock() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let secure: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();

        let unlocked = secure.unlock_memory();
        assert!(unlocked);

        let locked = secure.lock_memory();
        assert!(locked);
    }

    #[test]
    fn test_clone() {
        let mut array1: SecureArray<u8, 3> = SecureArray::empty().unwrap();
        array1.unlock_mut(|slice| {
            slice[0] = 1;
            slice[1] = 2;
            slice[2] = 3;
        });

        let array2 = array1.clone();

        array2.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        array1.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }

    #[test]
    fn test_thread_safety() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let arc_array = Arc::new(Mutex::new(array));
        let mut handles = Vec::new();

        for _ in 0..5u8 {
            let array_clone = Arc::clone(&arc_array);
            let handle = std::thread::spawn(move || {
                let mut guard = array_clone.lock().unwrap();
                guard.unlock_mut(|slice| {
                    slice[0] += 1;
                });
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let final_array = arc_array.lock().unwrap();
        final_array.unlock(|slice| {
            assert_eq!(slice[0], 6);
            assert_eq!(slice[1], 2);
            assert_eq!(slice[2], 3);
        });
    }

    // Re-runs this test binary in a child process with a marker argument. The
    // child then reads locked memory through `Index`; the parent asserts that
    // the child died from an access violation rather than exiting normally.
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_ARRAY_LOCKED";

        if std::env::args().any(|a| a == arg) {
            let exposed: &mut [u8; 3] = &mut [1, 2, 3];
            let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
            let _value = core::hint::black_box(array[0]);

            // Only reached if the locked read did not fault.
            std::process::exit(1);
        }

        let child = Command::new(std::env::current_exe().unwrap())
            .arg("array::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let output = child.wait_with_output().expect("Failed to wait on child");
        let status = output.status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                signal == libc::SIGSEGV || signal == libc::SIGBUS,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }

    #[test]
    fn test_unlock_mut() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let mut array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();

        array.unlock_mut(|slice| {
            slice[1] = 100;
        });

        array.unlock(|slice| {
            assert_eq!(slice, &[1, 100, 3]);
        });
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let exposed: &mut [u8; 3] = &mut [1, 2, 3];
        let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
        let json_string = serde_json::to_string(&array).expect("Serialization failed");
        let json_bytes = serde_json::to_vec(&array).expect("Serialization failed");

        let deserialized_string: SecureArray<u8, 3> =
            serde_json::from_str(&json_string).expect("Deserialization failed");

        let deserialized_bytes: SecureArray<u8, 3> =
            serde_json::from_slice(&json_bytes).expect("Deserialization failed");

        deserialized_string.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });

        deserialized_bytes.unlock(|slice| {
            assert_eq!(slice, &[1, 2, 3]);
        });
    }
}