1#[cfg(not(feature = "std"))]
2use alloc::{Layout, alloc, dealloc};
3
4use super::Error;
5use core::{marker::PhantomData, mem, ptr::NonNull};
6use zeroize::Zeroize;
7
8#[cfg(feature = "std")]
9use super::page_size;
10#[cfg(feature = "std")]
11use memsec::Prot;
12
13pub struct SecureArray<T, const LENGTH: usize>
41where
42 T: Zeroize,
43{
44 ptr: NonNull<T>,
45 _marker: PhantomData<T>,
46}
47
48unsafe impl<T: Zeroize + Send, const LENGTH: usize> Send for SecureArray<T, LENGTH> {}
49unsafe impl<T: Zeroize + Send + Sync, const LENGTH: usize> Sync for SecureArray<T, LENGTH> {}
50
51impl<T, const LENGTH: usize> SecureArray<T, LENGTH>
52where
53 T: Zeroize,
54{
55 pub fn empty() -> Result<Self, Error> {
58 let size = LENGTH * mem::size_of::<T>();
59 if size == 0 {
60 return Err(Error::AllocationFailed);
62 }
63
64 #[cfg(feature = "std")]
65 let new_ptr = {
66 let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
67 let allocated_ptr = unsafe { memsec::malloc_sized(aligned_size) };
68 allocated_ptr.ok_or(Error::AllocationFailed)?.as_ptr() as *mut T
69 };
70
71 #[cfg(not(feature = "std"))]
72 let new_ptr = {
73 let layout = Layout::from_size_align(size, mem::align_of::<T>())
74 .map_err(|_| Error::AllocationFailed)?;
75 let ptr = unsafe { alloc::alloc(layout) as *mut T };
76 if ptr.is_null() {
77 return Err(Error::AllocationFailed);
78 }
79 ptr
80 };
81
82 let non_null = NonNull::new(new_ptr).ok_or(Error::NullAllocation)?;
83
84 let secure_array = SecureArray {
85 ptr: non_null,
86 _marker: PhantomData,
87 };
88
89 let (encrypted, locked) = secure_array.lock_memory();
90
91 #[cfg(feature = "std")]
92 if !locked {
93 return Err(Error::LockFailed);
94 }
95
96 #[cfg(feature = "std")]
97 if !encrypted {
98 return Err(Error::CryptProtectMemoryFailed);
99 }
100
101 Ok(secure_array)
102 }
103
104 pub fn new(mut content: [T; LENGTH]) -> Result<Self, Error> {
106 let secure_array = Self::empty()?;
107
108 secure_array.unlock_memory();
109
110 unsafe {
111 core::ptr::copy_nonoverlapping(
113 content.as_ptr(),
114 secure_array.ptr.as_ptr(),
115 LENGTH,
116 );
117 }
118
119 content.zeroize();
120
121 let (encrypted, locked) = secure_array.lock_memory();
122
123 #[cfg(feature = "std")]
124 if !locked {
125 return Err(Error::LockFailed);
126 }
127
128 #[cfg(feature = "std")]
129 if !encrypted {
130 return Err(Error::CryptProtectMemoryFailed);
131 }
132
133 Ok(secure_array)
134 }
135
136 pub fn len(&self) -> usize {
137 LENGTH
138 }
139
140 pub fn as_ptr(&self) -> *const T {
141 self.ptr.as_ptr()
142 }
143
144 pub fn as_mut_ptr(&mut self) -> *mut u8 {
145 self.ptr.as_ptr() as *mut u8
146 }
147
148 #[allow(dead_code)]
149 fn allocated_byte_size(&self) -> usize {
150 let size = self.len() * mem::size_of::<T>();
151 #[cfg(feature = "std")]
152 {
153 (size + page_size() - 1) & !(page_size() - 1)
154 }
155 #[cfg(not(feature = "std"))]
156 {
157 size }
159 }
160
161 #[cfg(all(feature = "std", windows))]
162 fn encypt_memory(&self) -> bool {
163 let ptr = self.as_ptr() as *mut u8;
164 super::crypt_protect_memory(ptr, self.allocated_byte_size())
165 }
166
167 #[cfg(all(feature = "std", windows))]
168 fn decrypt_memory(&self) -> bool {
169 let ptr = self.as_ptr() as *mut u8;
170 super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
171 }
172
173 pub(crate) fn lock_memory(&self) -> (bool, bool) {
174 #[cfg(feature = "std")]
175 {
176 #[cfg(windows)]
177 {
178 let encrypt_ok = self.encypt_memory();
179 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
180 (encrypt_ok, mprotect_ok)
181 }
182 #[cfg(unix)]
183 {
184 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
185 (true, mprotect_ok)
186 }
187 }
188 #[cfg(not(feature = "std"))]
189 {
190 (true, true) }
192 }
193
194 pub(crate) fn unlock_memory(&self) -> (bool, bool) {
195 #[cfg(feature = "std")]
196 {
197 #[cfg(windows)]
198 {
199 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
200 if !mprotect_ok {
201 return (false, false);
202 }
203 let decrypt_ok = self.decrypt_memory();
204 (decrypt_ok, mprotect_ok)
205 }
206 #[cfg(unix)]
207 {
208 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
209 (true, mprotect_ok)
210 }
211 }
212
213 #[cfg(not(feature = "std"))]
214 {
215 (true, true) }
217 }
218
219 pub fn unlocked_scope<F, R>(&self, f: F) -> R
221 where
222 F: FnOnce(&[T]) -> R,
223 {
224 self.unlock_memory();
225 let slice = unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), LENGTH) };
226 let result = f(slice);
227 self.lock_memory();
228 result
229 }
230
231 pub fn unlocked_mut_scope<F, R>(&mut self, f: F) -> R
233 where
234 F: FnOnce(&mut [T]) -> R,
235 {
236 self.unlock_memory();
237 let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), LENGTH) };
238 let result = f(slice);
239 self.lock_memory();
240 result
241 }
242
243 pub fn erase(&mut self) {
245 self.unlocked_mut_scope(|slice| {
246 for element in slice.iter_mut() {
247 element.zeroize();
248 }
249 });
250 }
251}
252
253impl<T: Zeroize, const LENGTH: usize> core::ops::Index<usize> for SecureArray<T, LENGTH> {
254 type Output = T;
255 fn index(&self, index: usize) -> &Self::Output {
256 assert!(index < self.len(), "Index out of bounds");
257 unsafe {
258 let ptr = self.ptr.as_ptr().add(index);
259 let reference = &*ptr;
260 reference
261 }
262 }
263}
264
265impl<T: Zeroize, const LENGTH: usize> core::ops::IndexMut<usize> for SecureArray<T, LENGTH> {
266 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
267 assert!(index < self.len(), "Index out of bounds");
268 unsafe {
269 let ptr = self.ptr.as_ptr().add(index);
270 let reference = &mut *ptr;
271 reference
272 }
273 }
274}
275
276impl<T: Zeroize, const LENGTH: usize> Drop for SecureArray<T, LENGTH> {
277 fn drop(&mut self) {
278 self.erase();
279 self.unlock_memory();
280
281 let size = LENGTH * mem::size_of::<T>();
282 if size == 0 {
283 return;
284 }
285
286 unsafe {
287 #[cfg(feature = "std")]
288 {
289 memsec::free(self.ptr);
290 }
291 #[cfg(not(feature = "std"))]
292 {
293 let layout = Layout::from_size_align_unchecked(size, mem::align_of::<T>());
295 dealloc(self.ptr.as_ptr() as *mut u8, layout);
296 }
297 }
298 }
299}
300
301impl<T: Clone + Zeroize, const LENGTH: usize> Clone for SecureArray<T, LENGTH> {
302 fn clone(&self) -> Self {
303 let mut new_array = Self::empty().unwrap();
304 self.unlocked_scope(|src_slice| {
305 new_array.unlocked_mut_scope(|dest_slice| {
306 dest_slice.clone_from_slice(src_slice);
307 });
308 });
309 new_array
310 }
311}
312
313impl<T, const LENGTH: usize> TryFrom<[T; LENGTH]> for SecureArray<T, LENGTH>
314where
315 T: Zeroize,
316{
317 type Error = Error;
318 fn try_from(s: [T; LENGTH]) -> Result<Self, Error> {
319 Self::new(s)
320 }
321}
322
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    /// A freshly built array reports its length and exposes its contents
    /// inside an unlocked scope.
    #[test]
    fn test_creation() {
        let array = SecureArray::<u8, 3>::new([1, 2, 3]).unwrap();
        assert_eq!(array.len(), 3);
        array.unlocked_scope(|contents| assert_eq!(contents, &[1, 2, 3]));
    }

    /// `erase` leaves every element zeroed.
    #[test]
    fn test_erase() {
        let mut array = SecureArray::<u8, 3>::new([1, 2, 3]).unwrap();
        array.erase();
        array.unlocked_scope(|contents| assert_eq!(contents, &[0u8; 3]));
    }

    /// A clone carries the same contents and the source stays readable.
    #[test]
    fn test_clone() {
        let mut original = SecureArray::<u8, 3>::empty().unwrap();
        original.unlocked_mut_scope(|contents| {
            for (slot, value) in contents.iter_mut().zip([1u8, 2, 3]) {
                *slot = value;
            }
        });

        let copy = original.clone();

        copy.unlocked_scope(|contents| assert_eq!(contents, &[1, 2, 3]));
        original.unlocked_scope(|contents| assert_eq!(contents, &[1, 2, 3]));
    }

    /// Five threads each increment the first element through a shared
    /// `Mutex`; the final value reflects every increment.
    #[test]
    fn test_thread_safety() {
        let shared = Arc::new(Mutex::new(SecureArray::new([1, 2, 3]).unwrap()));

        let workers: Vec<_> = (0..5u8)
            .map(|_| {
                let handle_to_array = Arc::clone(&shared);
                std::thread::spawn(move || {
                    let mut guard = handle_to_array.lock().unwrap();
                    guard.unlocked_mut_scope(|contents| contents[0] += 1);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let final_array = shared.lock().unwrap();
        final_array.unlocked_scope(|contents| assert_eq!(contents[0], 6));
    }

    /// Indexing a locked array must crash the process.
    ///
    /// The test re-executes the current test binary with a marker argument;
    /// the child performs the faulting access and the parent inspects how the
    /// child terminated.
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_ARRAY_LOCKED";

        // Child branch: touch the locked memory. Reaching `exit(1)` means the
        // access did not fault.
        if std::env::args().any(|candidate| candidate == arg) {
            let array: SecureArray<u8, 3> = SecureArray::new([1, 2, 3]).unwrap();
            let _value = core::hint::black_box(array[0]);

            std::process::exit(1);
        }

        // Parent branch: run only this test in a child process.
        let mut command = Command::new(std::env::current_exe().unwrap());
        command
            .arg("array::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let child = command.spawn().expect("Failed to spawn child process");

        let output = child.wait_with_output().expect("Failed to wait on child");
        let status = output.status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;

            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            let is_memory_fault = signal == libc::SIGSEGV || signal == libc::SIGBUS;
            assert!(
                is_memory_fault,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }

    /// Writes made in a mutable scope are visible in a later shared scope.
    #[test]
    fn test_mutable_access_in_scope() {
        let mut array = SecureArray::<u8, 3>::new([1, 2, 3]).unwrap();

        array.unlocked_mut_scope(|contents| contents[1] = 100);

        array.unlocked_scope(|contents| assert_eq!(contents, &[1, 100, 3]));
    }
}