#[cfg(not(feature = "std"))]
use alloc::alloc::{alloc, dealloc, Layout};

#[cfg(feature = "std")]
use std::vec::Vec;

use super::Error;
use core::{
    marker::PhantomData,
    mem,
    ops::{Bound, RangeBounds},
    ptr::{self, NonNull},
};
use zeroize::Zeroize;

#[cfg(feature = "std")]
use super::page_size;
#[cfg(feature = "std")]
use memsec::Prot;

pub type SecureBytes = SecureVec<u8>;

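/// A heap-allocated vector for sensitive data. With the `std` feature the
/// backing allocation comes from `memsec::malloc_sized` and is kept
/// inaccessible (`Prot::NoAccess`, plus `CryptProtectMemory` on Windows)
/// while at rest; without `std` it falls back to the global allocator and the
/// lock/unlock calls are no-ops. Contents are zeroized by `erase` and on
/// drop. Elements should only be touched through the `*_scope` methods, which
/// unlock the memory for the duration of a closure and re-lock it afterwards.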
pub struct SecureVec<T>
where
    T: Zeroize,
{
    ptr: NonNull<T>,
    pub(crate) len: usize,
    pub(crate) capacity: usize,
    _marker: PhantomData<T>,
}

unsafe impl<T: Zeroize + Send> Send for SecureVec<T> {}
unsafe impl<T: Zeroize + Send + Sync> Sync for SecureVec<T> {}

impl<T: Zeroize> SecureVec<T> {
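    /// Allocates a locked, empty vector with room for a single element.
    ///
    /// Returns an error if allocation fails or (with `std`) if the memory
    /// cannot be protected.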
    pub fn new() -> Result<Self, Error> {
        let capacity = 1;
        let size = capacity * mem::size_of::<T>();

        #[cfg(feature = "std")]
        let ptr = unsafe {
            let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
            let allocated_ptr = memsec::malloc_sized(aligned_size);
            let ptr = allocated_ptr.ok_or(Error::AllocationFailed)?;
            ptr.as_ptr() as *mut T
        };

        #[cfg(not(feature = "std"))]
        let ptr = {
            let layout = Layout::from_size_align(size, mem::align_of::<T>())
                .map_err(|_| Error::AllocationFailed)?;
            let ptr = unsafe { alloc(layout) as *mut T };
            if ptr.is_null() {
                return Err(Error::AllocationFailed);
            }
            ptr
        };

        let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
        let secure = SecureVec {
            ptr: non_null,
            len: 0,
            capacity,
            _marker: PhantomData,
        };

        let (encrypted, locked) = secure.lock_memory();

        #[cfg(feature = "std")]
        if !locked {
            return Err(Error::LockFailed);
        }

        #[cfg(feature = "std")]
        if !encrypted {
            return Err(Error::CryptProtectMemoryFailed);
        }

        Ok(secure)
    }

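    /// Allocates a locked, empty vector able to hold at least `capacity`
    /// elements. A requested capacity of zero is rounded up to one so that a
    /// valid, lockable allocation always exists.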
    pub fn with_capacity(mut capacity: usize) -> Result<Self, Error> {
        if capacity == 0 {
            capacity = 1;
        }

        let size = capacity * mem::size_of::<T>();

        #[cfg(feature = "std")]
        let ptr = unsafe {
            let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
            let allocated_ptr = memsec::malloc_sized(aligned_size);
            let ptr = allocated_ptr.ok_or(Error::AllocationFailed)?;
            ptr.as_ptr() as *mut T
        };

        #[cfg(not(feature = "std"))]
        let ptr = {
            let layout = Layout::from_size_align(size, mem::align_of::<T>())
                .map_err(|_| Error::AllocationFailed)?;
            let ptr = unsafe { alloc(layout) as *mut T };
            if ptr.is_null() {
                return Err(Error::AllocationFailed);
            }
            ptr
        };

        let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;

        let secure = SecureVec {
            ptr: non_null,
            len: 0,
            capacity,
            _marker: PhantomData,
        };

        let (encrypted, locked) = secure.lock_memory();

        #[cfg(feature = "std")]
        if !locked {
            return Err(Error::LockFailed);
        }

        #[cfg(feature = "std")]
        if !encrypted {
            return Err(Error::CryptProtectMemoryFailed);
        }

        Ok(secure)
    }

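    /// Moves the contents of `vec` into freshly allocated, locked memory.
    /// The source vector is zeroized before being dropped so that no copy of
    /// the data remains in ordinary heap memory.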
    #[cfg(feature = "std")]
    pub fn from_vec(mut vec: Vec<T>) -> Result<Self, Error> {
        if vec.capacity() == 0 {
            vec.reserve(1);
        }

        let capacity = vec.capacity();
        let len = vec.len();

        let size = capacity * mem::size_of::<T>();

        let ptr = unsafe {
            let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
            let allocated_ptr = memsec::malloc_sized(aligned_size);
            match allocated_ptr {
                Some(allocated) => allocated.as_ptr() as *mut T,
                None => {
                    vec.zeroize();
                    return Err(Error::AllocationFailed);
                }
            }
        };

        unsafe {
            core::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, len);
        }

        vec.zeroize();
        drop(vec);

        let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;

        let secure = SecureVec {
            ptr: non_null,
            len,
            capacity,
            _marker: PhantomData,
        };

        let (encrypted, locked) = secure.lock_memory();

        if !locked {
            return Err(Error::LockFailed);
        }

        if !encrypted {
            return Err(Error::CryptProtectMemoryFailed);
        }

        Ok(secure)
    }

    pub fn len(&self) -> usize {
        self.len
    }

    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr() as *mut u8
    }

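    /// Size in bytes of the underlying allocation: the capacity in bytes,
    /// rounded up to a whole number of pages when the `std` feature (and
    /// thus `memsec`) is in use.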
    #[allow(dead_code)]
    fn allocated_byte_size(&self) -> usize {
        let size = self.capacity * mem::size_of::<T>();
        #[cfg(feature = "std")]
        {
            (size + page_size() - 1) & !(page_size() - 1)
        }
        #[cfg(not(feature = "std"))]
        {
            size
        }
    }

    #[cfg(all(feature = "std", windows))]
    fn encrypt_memory(&self) -> bool {
        let ptr = self.as_ptr() as *mut u8;
        super::crypt_protect_memory(ptr, self.allocated_byte_size())
    }

    #[cfg(all(feature = "std", windows))]
    fn decrypt_memory(&self) -> bool {
        let ptr = self.as_ptr() as *mut u8;
        super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
    }

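    /// Protects the allocation at rest. With `std`, the pages are set to
    /// `Prot::NoAccess` (and additionally encrypted with `CryptProtectMemory`
    /// on Windows); without `std` this is a no-op. Returns
    /// `(encrypted, locked)` flags describing what succeeded.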
    pub(crate) fn lock_memory(&self) -> (bool, bool) {
        #[cfg(feature = "std")]
        {
            #[cfg(windows)]
            {
                let encrypt_ok = self.encrypt_memory();
                let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
                (encrypt_ok, mprotect_ok)
            }
            #[cfg(unix)]
            {
                let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
                (true, mprotect_ok)
            }
        }
        #[cfg(not(feature = "std"))]
        {
            (true, true)
        }
    }

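    /// Reverses `lock_memory`: restores `Prot::ReadWrite` (and decrypts on
    /// Windows) so the contents can be read or written. Callers are expected
    /// to re-lock as soon as they are done, which the `*_scope` helpers do
    /// automatically.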
    pub(crate) fn unlock_memory(&self) -> (bool, bool) {
        #[cfg(feature = "std")]
        {
            #[cfg(windows)]
            {
                let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
                if !mprotect_ok {
                    return (false, false);
                }
                let decrypt_ok = self.decrypt_memory();
                (decrypt_ok, mprotect_ok)
            }
            #[cfg(unix)]
            {
                let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
                (true, mprotect_ok)
            }
        }

        #[cfg(not(feature = "std"))]
        {
            (true, true)
        }
    }

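    /// Runs `f` with the memory unlocked, re-locking it before returning.
    ///
    /// Illustrative sketch (not compiled as a doctest; assumes a
    /// `SecureVec<u8>` named `secure` built elsewhere):
    ///
    /// ```ignore
    /// secure.unlock_scope(|v| {
    ///     // Indexing is only safe while the memory is unlocked.
    ///     assert_eq!(v[0], 1);
    /// });
    /// ```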
    pub fn unlock_scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&SecureVec<T>) -> R,
    {
        self.unlock_memory();
        let result = f(self);
        self.lock_memory();
        result
    }

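    /// Unlocks the memory, exposes the contents as `&[T]` to `f`, then locks
    /// the memory again and returns the closure's result.
    ///
    /// Illustrative sketch (not compiled as a doctest; assumes a
    /// `SecureVec<u8>` named `secure`):
    ///
    /// ```ignore
    /// let sum: u8 = secure.slice_scope(|slice| slice.iter().copied().sum());
    /// ```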
    pub fn slice_scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&[T]) -> R,
    {
        unsafe {
            self.unlock_memory();
            let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
            let result = f(slice);
            self.lock_memory();
            result
        }
    }

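    /// Mutable counterpart of `slice_scope`: the closure receives `&mut [T]`
    /// and the memory is re-locked afterwards.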
    pub fn slice_mut_scope<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut [T]) -> R,
    {
        unsafe {
            self.unlock_memory();
            let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
            let result = f(slice);
            self.lock_memory();
            result
        }
    }

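    /// Like `slice_scope`, but hands the closure an iterator over the
    /// elements instead of the slice itself.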
    pub fn iter_scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce(core::slice::Iter<T>) -> R,
    {
        unsafe {
            self.unlock_memory();
            let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
            let iter = slice.iter();
            let result = f(iter);
            self.lock_memory();
            result
        }
    }

    pub fn iter_mut_scope<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(core::slice::IterMut<T>) -> R,
    {
        unsafe {
            self.unlock_memory();
            let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
            let iter = slice.iter_mut();
            let result = f(iter);
            self.lock_memory();
            result
        }
    }

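    /// Zeroizes every element in place and resets the length to zero; the
    /// capacity and the locked allocation are kept.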
    pub fn erase(&mut self) {
        unsafe {
            self.unlock_memory();
            let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
            for elem in slice.iter_mut() {
                elem.zeroize();
            }
            self.clear();
            self.lock_memory();
        }
    }

    pub fn clear(&mut self) {
        self.len = 0;
    }

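    /// Appends `value`, growing the locked allocation by one element when it
    /// is full. Growth copies the old contents into a fresh allocation,
    /// zeroizes the old memory, and frees it before re-locking.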
    pub fn push(&mut self, value: T) {
        if self.len >= self.capacity {
            let new_capacity = if self.capacity == 0 {
                1
            } else {
                self.capacity + 1
            };

            let items_byte_size = new_capacity * mem::size_of::<T>();

            #[cfg(feature = "std")]
            let new_ptr = unsafe {
                let aligned_allocation_size =
                    (items_byte_size + page_size() - 1) & !(page_size() - 1);
                memsec::malloc_sized(aligned_allocation_size)
                    .expect("Failed to allocate memory")
                    .as_ptr() as *mut T
            };

            #[cfg(not(feature = "std"))]
            let new_ptr = {
                let layout = Layout::from_size_align(items_byte_size, mem::align_of::<T>())
                    .expect("Failed to create layout");
                let ptr = unsafe { alloc(layout) as *mut T };
                if ptr.is_null() {
                    panic!("Null pointer returned from alloc");
                }
                ptr
            };

            unsafe {
                self.unlock_memory();
                core::ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr, self.len);
                if self.capacity > 0 {
                    let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
                    for elem in slice.iter_mut() {
                        elem.zeroize();
                    }
                }
                #[cfg(feature = "std")]
                memsec::free(self.ptr);

                #[cfg(not(feature = "std"))]
                {
                    let old_size = self.capacity * mem::size_of::<T>();
                    let old_layout =
                        Layout::from_size_align_unchecked(old_size, mem::align_of::<T>());
                    dealloc(self.ptr.as_ptr() as *mut u8, old_layout);
                }
            }

            self.ptr = NonNull::new(new_ptr).expect("Failed to create NonNull");
            self.capacity = new_capacity;

            unsafe {
                core::ptr::write(self.ptr.as_ptr().add(self.len), value);
                self.len += 1;
                self.lock_memory();
            }
        } else {
            unsafe {
                self.unlock_memory();
                core::ptr::write(self.ptr.as_ptr().add(self.len), value);
                self.len += 1;
                self.lock_memory();
            }
        }
    }

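    /// Removes the elements in `range` and returns an iterator over them,
    /// unlocking the memory for the lifetime of the returned `Drain`. When
    /// the `Drain` is dropped, the remaining tail is moved down, the
    /// now-unused trailing slots are zeroized, and the memory is locked
    /// again.
    ///
    /// Illustrative sketch (not compiled as a doctest; assumes a
    /// `SecureVec<u8>` named `secure`):
    ///
    /// ```ignore
    /// let mut drained = secure.drain(..3);
    /// assert_eq!(drained.next(), Some(1));
    /// drop(drained); // compacts, zeroizes, and re-locks
    /// ```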
    pub fn drain<R>(&mut self, range: R) -> Drain<'_, T>
    where
        R: RangeBounds<usize>,
    {
        self.unlock_memory();

        let original_len = self.len;

        let (drain_start_idx, drain_end_idx) = resolve_range_indices(range, original_len);

        let tail_len = original_len - drain_end_idx;

        self.len = drain_start_idx;

        Drain {
            vec_ref: self,
            drain_start_index: drain_start_idx,
            current_drain_iter_index: drain_start_idx,
            drain_end_index: drain_end_idx,
            original_vec_len: original_len,
            tail_len,
            _marker: PhantomData,
        }
    }
}

impl<T: Clone + Zeroize> Clone for SecureVec<T> {
    fn clone(&self) -> Self {
        let mut new_vec: SecureVec<T> = SecureVec::with_capacity(self.capacity).unwrap();

        new_vec.unlock_memory();
        self.unlock_memory();

        unsafe {
            for i in 0..self.len {
                let value = (*self.ptr.as_ptr().add(i)).clone();
                core::ptr::write(new_vec.ptr.as_ptr().add(i), value);
            }
        }

        new_vec.len = self.len;
        self.lock_memory();
        new_vec.lock_memory();
        new_vec
    }
}

impl<T: Zeroize> Drop for SecureVec<T> {
    fn drop(&mut self) {
        self.erase();
        self.unlock_memory();
        unsafe {
            #[cfg(feature = "std")]
            memsec::free(self.ptr);

            #[cfg(not(feature = "std"))]
            {
                let layout = Layout::from_size_align_unchecked(
                    self.allocated_byte_size(),
                    mem::align_of::<T>(),
                );
                dealloc(self.ptr.as_ptr() as *mut u8, layout);
            }
        }
    }
}

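// Indexing reads straight from the backing pointer and does not unlock the
// memory first; with `std` enabled the pages are `NoAccess` at rest, so use
// it only inside `unlock_scope` (see `test_index_should_fail_when_locked`).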
impl<T: Zeroize> core::ops::Index<usize> for SecureVec<T> {
    type Output = T;

    fn index(&self, index: usize) -> &Self::Output {
        assert!(index < self.len, "Index out of bounds");
        unsafe { &*self.ptr.as_ptr().add(index) }
    }
}

#[cfg(feature = "serde")]
impl serde::Serialize for SecureVec<u8> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.slice_scope(|slice| serializer.collect_seq(slice.iter()))
    }
}

#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureVec<u8> {
    fn deserialize<D>(deserializer: D) -> Result<SecureVec<u8>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct SecureVecVisitor;

        impl<'de> serde::de::Visitor<'de> for SecureVecVisitor {
            type Value = SecureVec<u8>;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "a sequence of bytes")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut vec = Vec::new();
                while let Some(byte) = seq.next_element::<u8>()? {
                    vec.push(byte);
                }
                SecureVec::from_vec(vec).map_err(serde::de::Error::custom)
            }
        }

        deserializer.deserialize_seq(SecureVecVisitor)
    }
}

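/// Draining iterator returned by [`SecureVec::drain`]. While it exists the
/// vector's memory stays unlocked; its `Drop` implementation compacts the
/// vector, zeroizes the vacated slots, and re-locks the memory.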
pub struct Drain<'a, T: Zeroize + 'a> {
    vec_ref: &'a mut SecureVec<T>,
    drain_start_index: usize,
    current_drain_iter_index: usize,
    drain_end_index: usize,
    original_vec_len: usize,
    tail_len: usize,
    _marker: PhantomData<&'a T>,
}

impl<'a, T: Zeroize> Iterator for Drain<'a, T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        if self.current_drain_iter_index < self.drain_end_index {
            unsafe {
                let item_ptr = self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index);
                let item = ptr::read(item_ptr);
                self.current_drain_iter_index += 1;
                Some(item)
            }
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.drain_end_index - self.current_drain_iter_index;
        (remaining, Some(remaining))
    }
}

impl<'a, T: Zeroize> ExactSizeIterator for Drain<'a, T> {}

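// Dropping the drain finishes the removal: any drained items not yet yielded
// are dropped, the tail is shifted down over the drained gap, the now-unused
// trailing slots are dropped where required and zeroized, the length is
// updated, and the memory is locked again.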
impl<'a, T: Zeroize> Drop for Drain<'a, T> {
    fn drop(&mut self) {
        unsafe {
            if mem::needs_drop::<T>() {
                let mut current_ptr =
                    self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index) as *mut T;
                let end_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index) as *mut T;
                while current_ptr < end_ptr {
                    ptr::drop_in_place(current_ptr);
                    current_ptr = current_ptr.add(1);
                }
            }

            let hole_dst_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_start_index) as *mut T;
            let tail_src_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index) as *mut T;

            if self.tail_len > 0 {
                ptr::copy(tail_src_ptr, hole_dst_ptr, self.tail_len);
            }

            let new_len = self.drain_start_index + self.tail_len;

            let mut current_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(new_len) as *mut T;
            let end_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(self.original_vec_len) as *mut T;

            let original_tail_start_ptr_val = tail_src_ptr as usize;

            while current_cleanup_ptr < end_cleanup_ptr {
                if mem::needs_drop::<T>() {
                    let current_ptr_val = current_cleanup_ptr as usize;
                    let original_tail_end_ptr_val =
                        original_tail_start_ptr_val + self.tail_len * mem::size_of::<T>();

                    if current_ptr_val >= original_tail_start_ptr_val
                        && current_ptr_val < original_tail_end_ptr_val
                    {
                        ptr::drop_in_place(current_cleanup_ptr);
                    }
                }

                (*current_cleanup_ptr).zeroize();
                current_cleanup_ptr = current_cleanup_ptr.add(1);
            }

            self.vec_ref.len = new_len;

            self.vec_ref.lock_memory();
        }
    }
}

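/// Converts a `RangeBounds<usize>` into concrete `(start, end)` indices,
/// panicking on inverted or out-of-bounds ranges, mirroring the behaviour of
/// `Vec::drain`.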
fn resolve_range_indices<R: RangeBounds<usize>>(range: R, len: usize) -> (usize, usize) {
    let start_bound = range.start_bound();
    let end_bound = range.end_bound();

    let start = match start_bound {
        Bound::Included(&s) => s,
        Bound::Excluded(&s) => s
            .checked_add(1)
            .unwrap_or_else(|| panic!("attempted to start drain at Excluded(usize::MAX)")),
        Bound::Unbounded => 0,
    };

    let end = match end_bound {
        Bound::Included(&e) => e
            .checked_add(1)
            .unwrap_or_else(|| panic!("attempted to end drain at Included(usize::MAX)")),
        Bound::Excluded(&e) => e,
        Bound::Unbounded => len,
    };

    if start > end {
        panic!(
            "drain range start ({}) must be less than or equal to end ({})",
            start, end
        );
    }
    if end > len {
        panic!(
            "drain range end ({}) out of bounds for slice of length {}",
            end, len
        );
    }

    (start, end)
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    #[test]
    fn test_creation() {
        let vec: Vec<u8> = vec![1, 2, 3];
        let _ = SecureVec::from_vec(vec);
        let _: SecureVec<u8> = SecureVec::new().unwrap();
        let _: SecureVec<u8> = SecureVec::with_capacity(3).unwrap();
    }

    #[test]
    fn lock_unlock() {
        let secure: SecureVec<u8> = SecureVec::new().unwrap();
        let size = secure.allocated_byte_size();
        assert!(size > 0);

        let (decrypted, unlocked) = secure.unlock_memory();
        assert!(decrypted);
        assert!(unlocked);

        let (encrypted, locked) = secure.lock_memory();
        assert!(encrypted);
        assert!(locked);

        let secure: SecureVec<u8> = SecureVec::from_vec(vec![]).unwrap();
        let size = secure.allocated_byte_size();
        assert!(size > 0);

        let (decrypted, unlocked) = secure.unlock_memory();
        assert!(decrypted);
        assert!(unlocked);

        let (encrypted, locked) = secure.lock_memory();
        assert!(encrypted);
        assert!(locked);

        let secure: SecureVec<u8> = SecureVec::with_capacity(0).unwrap();
        let size = secure.allocated_byte_size();
        assert!(size > 0);

        let (decrypted, unlocked) = secure.unlock_memory();
        assert!(decrypted);
        assert!(unlocked);

        let (encrypted, locked) = secure.lock_memory();
        assert!(encrypted);
        assert!(locked);
    }

    #[test]
    fn test_thread_safety() {
        let vec: Vec<u8> = vec![];
        let secure = SecureVec::from_vec(vec).unwrap();
        let secure = Arc::new(Mutex::new(secure));

        let mut handles = Vec::new();
        for i in 0..5u8 {
            let secure_clone = secure.clone();
            let handle = std::thread::spawn(move || {
                let mut secure = secure_clone.lock().unwrap();
                secure.push(i);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let sec = secure.lock().unwrap();
        sec.slice_scope(|slice| {
            assert_eq!(slice.len(), 5);
        });
    }

    #[test]
    fn test_clone() {
        let vec: Vec<u8> = vec![1, 2, 3];
        let secure1 = SecureVec::from_vec(vec).unwrap();
        let _secure2 = secure1.clone();
    }

    #[test]
    fn test_do_not_call_forget_on_drain() {
        let vec: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut secure = SecureVec::from_vec(vec).unwrap();
        let drain = secure.drain(..3);
        core::mem::forget(drain);
        secure.slice_scope(|secure| {
            assert_eq!(secure.len(), 0);
        });
    }

888
889 #[test]
890 fn test_drain() {
891 let vec: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
892 let mut secure = SecureVec::from_vec(vec).unwrap();
893 let mut drain = secure.drain(..3);
894 assert_eq!(drain.next(), Some(1));
895 assert_eq!(drain.next(), Some(2));
896 assert_eq!(drain.next(), Some(3));
897 assert_eq!(drain.next(), None);
898 drop(drain);
899 secure.slice_scope(|secure| {
900 assert_eq!(secure.len(), 7);
901 assert_eq!(secure, &[4, 5, 6, 7, 8, 9, 10]);
902 });
903 }
904
905 #[cfg(feature = "serde")]
906 #[test]
907 fn test_secure_vec_serde() {
908 let vec: Vec<u8> = vec![1, 2, 3];
909 let secure = SecureVec::from_vec(vec).unwrap();
910 let json = serde_json::to_vec(&secure).expect("Serialization failed");
911 let deserialized: SecureVec<u8> =
912 serde_json::from_slice(&json).expect("Deserialization failed");
913 deserialized.slice_scope(|slice| {
914 assert_eq!(slice, &[1, 2, 3]);
915 });
916 }
917
918 #[test]
919 fn test_erase() {
920 let vec: Vec<u8> = vec![1, 2, 3];
921 let mut secure = SecureVec::from_vec(vec).unwrap();
922 secure.erase();
923 secure.unlock_scope(|secure| {
924 assert_eq!(secure.len, 0);
925 assert_eq!(secure.capacity, 3);
926 });
927
928 secure.push(1);
929 secure.push(2);
930 secure.push(3);
931 secure.unlock_scope(|secure| {
932 assert_eq!(secure[0], 1);
933 assert_eq!(secure[1], 2);
934 assert_eq!(secure[2], 3);
935 assert_eq!(secure.len, 3);
936 assert_eq!(secure.capacity, 3);
937 });
938 }
939
940 #[test]
941 fn test_push() {
942 let vec: Vec<u8> = Vec::new();
943 let mut secure = SecureVec::from_vec(vec).unwrap();
944 for i in 0..30 {
945 secure.push(i);
946 }
947 }
948
949 #[test]
950 fn test_index() {
951 let vec: Vec<u8> = vec![1, 2, 3];
952 let secure = SecureVec::from_vec(vec).unwrap();
953 secure.unlock_scope(|secure| {
954 assert_eq!(secure[0], 1);
955 assert_eq!(secure[1], 2);
956 assert_eq!(secure[2], 3);
957 });
958 }
959
960 #[test]
961 fn test_slice_scoped() {
962 let vec: Vec<u8> = vec![1, 2, 3];
963 let secure = SecureVec::from_vec(vec).unwrap();
964 secure.slice_scope(|slice| {
965 assert_eq!(slice, &[1, 2, 3]);
966 });
967 }
968
969 #[test]
970 fn test_slice_mut_scoped() {
971 let vec: Vec<u8> = vec![1, 2, 3];
972 let mut secure = SecureVec::from_vec(vec).unwrap();
973
974 secure.slice_mut_scope(|slice| {
975 slice[0] = 4;
976 assert_eq!(slice, &mut [4, 2, 3]);
977 });
978
979 secure.slice_scope(|slice| {
980 assert_eq!(slice, &[4, 2, 3]);
981 });
982 }
983
984 #[test]
985 fn test_iter_scoped() {
986 let vec: Vec<u8> = vec![1, 2, 3];
987 let secure = SecureVec::from_vec(vec).unwrap();
988 let sum: u8 = secure.iter_scope(|iter| iter.map(|&x| x).sum());
989
990 assert_eq!(sum, 6);
991
992 let secure: SecureVec<u8> = SecureVec::with_capacity(3).unwrap();
993 let sum: u8 = secure.iter_scope(|iter| iter.map(|&x| x).sum());
994
995 assert_eq!(sum, 0);
996 }
997
998 #[test]
999 fn test_iter_mut_scoped() {
1000 let vec: Vec<u8> = vec![1, 2, 3];
1001 let mut secure = SecureVec::from_vec(vec).unwrap();
1002 secure.iter_mut_scope(|iter| {
1003 for elem in iter {
1004 *elem += 1;
1005 }
1006 });
1007
1008 secure.slice_scope(|slice| {
1009 assert_eq!(slice, &[2, 3, 4]);
1010 });
1011 }
1012
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_SECUREVEC_LOCKED";

        if std::env::args().any(|a| a == arg) {
            let vec: Vec<u8> = vec![1, 2, 3];
            let secure = SecureVec::from_vec(vec).unwrap();
            let _value = core::hint::black_box(secure[0]);

            std::process::exit(1);
        }

        let child = Command::new(std::env::current_exe().unwrap())
            .arg("vec::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let output = child.wait_with_output().expect("Failed to wait on child");
        let status = output.status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                signal == libc::SIGSEGV || signal == libc::SIGBUS,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }
}