1#[cfg(not(feature = "std"))]
2use alloc::{Layout, alloc, dealloc};
3
4#[cfg(feature = "std")]
5use std::vec::Vec;
6
7use super::Error;
8use core::{
9 marker::PhantomData,
10 mem,
11 ops::{Bound, RangeBounds},
12 ptr::{self, NonNull},
13};
14use zeroize::{DefaultIsZeroes, Zeroize};
15
16#[cfg(feature = "std")]
17use super::page_size;
18#[cfg(feature = "std")]
19use memsec::Prot;
20
/// Convenience alias for a secure byte buffer.
pub type SecureBytes = SecureVec<u8>;

/// A growable buffer of `T` whose backing memory is kept protected while it is
/// not being accessed.
///
/// With the `std` feature, the buffer comes from `memsec::malloc_sized`. Between
/// accesses it is made inaccessible through `super::mprotect(.., Prot::NoAccess)`,
/// and on Windows it is also passed through `super::crypt_protect_memory`.
/// Contents must be read and written through the `*_scope` methods, which
/// unlock the buffer around the closure and lock it again afterwards.
pub struct SecureVec<T>
where
    T: Zeroize,
{
    // Start of the element buffer; never null.
    ptr: NonNull<T>,
    // Number of initialized elements at the start of the buffer.
    pub(crate) len: usize,
    // Number of elements the buffer can hold before it must be reallocated.
    pub(crate) capacity: usize,
    // Tells the compiler this type logically owns values of `T`.
    _marker: PhantomData<T>,
}

// SAFETY: `SecureVec` uniquely owns its buffer, so moving it to another thread
// moves the `T` values with it; that is sound when `T: Send`.
unsafe impl<T: Zeroize + Send> Send for SecureVec<T> {}
// NOTE(review): methods taking `&self` (`slice_scope`, `iter_scope`,
// `unlock_scope`, `Clone::clone`) toggle the buffer's protection. Two threads
// sharing `&SecureVec` can therefore race: one may re-lock (or, on Windows,
// re-encrypt) the memory while the other is still reading it. Confirm whether
// `Sync` is intended.
unsafe impl<T: Zeroize + Send + Sync> Sync for SecureVec<T> {}
84
85impl<T: Zeroize> SecureVec<T> {
86 pub fn new() -> Result<Self, Error> {
87 let capacity = 1;
89 let size = capacity * mem::size_of::<T>();
90
91 #[cfg(feature = "std")]
92 let ptr = unsafe {
93 let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
94 let allocated_ptr = memsec::malloc_sized(aligned_size);
95 let ptr = allocated_ptr.ok_or(Error::AllocationFailed)?;
96 ptr.as_ptr() as *mut T
97 };
98
99 #[cfg(not(feature = "std"))]
100 let ptr = {
101 let layout = Layout::from_size_align(size, mem::align_of::<T>())
102 .map_err(|_| Error::AllocationFailed)?;
103 let ptr = unsafe { alloc::alloc(layout) as *mut T };
104 if ptr.is_null() {
105 return Err(Error::AllocationFailed);
106 }
107 ptr
108 };
109
110 let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
111 let secure = SecureVec {
112 ptr: non_null,
113 len: 0,
114 capacity,
115 _marker: PhantomData,
116 };
117
118 let (encrypted, locked) = secure.lock_memory();
119
120 #[cfg(feature = "std")]
121 if !locked {
122 return Err(Error::LockFailed);
123 }
124
125 #[cfg(feature = "std")]
126 if !encrypted {
127 return Err(Error::CryptProtectMemoryFailed);
128 }
129
130 Ok(secure)
131 }
132
133 pub fn new_with_capacity(mut capacity: usize) -> Result<Self, Error> {
134 if capacity == 0 {
135 capacity = 1;
136 }
137
138 let size = capacity * mem::size_of::<T>();
139
140 #[cfg(feature = "std")]
141 let ptr = unsafe {
142 let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
143 let allocated_ptr = memsec::malloc_sized(aligned_size);
144 let ptr = allocated_ptr.ok_or(Error::AllocationFailed)?;
145 ptr.as_ptr() as *mut T
146 };
147
148 #[cfg(not(feature = "std"))]
149 let ptr = {
150 let layout = Layout::from_size_align(size, mem::align_of::<T>())
151 .map_err(|_| Error::AllocationFailed)?;
152 let ptr = unsafe { alloc::alloc(layout) as *mut T };
153 if ptr.is_null() {
154 return Err(Error::AllocationFailed);
155 }
156 ptr
157 };
158
159 let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
160
161 let secure = SecureVec {
162 ptr: non_null,
163 len: 0,
164 capacity,
165 _marker: PhantomData,
166 };
167
168 let (encrypted, locked) = secure.lock_memory();
169
170 #[cfg(feature = "std")]
171 if !locked {
172 return Err(Error::LockFailed);
173 }
174
175 #[cfg(feature = "std")]
176 if !encrypted {
177 return Err(Error::CryptProtectMemoryFailed);
178 }
179
180 Ok(secure)
181 }
182
183 #[cfg(feature = "std")]
184 pub fn from_vec(mut vec: Vec<T>) -> Result<Self, Error> {
185 if vec.capacity() == 0 {
186 vec.reserve(1);
187 }
188
189 let capacity = vec.capacity();
190 let len = vec.len();
191
192 let size = capacity * mem::size_of::<T>();
194
195 let ptr = unsafe {
196 let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
197 let allocated_ptr = memsec::malloc_sized(aligned_size);
198 if allocated_ptr.is_none() {
199 vec.zeroize();
200 return Err(Error::AllocationFailed);
201 } else {
202 allocated_ptr.unwrap().as_ptr() as *mut T
203 }
204 };
205
206 unsafe {
208 core::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, len);
209 }
210
211 vec.zeroize();
213 drop(vec);
214
215 let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
216
217 let secure = SecureVec {
218 ptr: non_null,
219 len,
220 capacity,
221 _marker: PhantomData,
222 };
223
224 let (encrypted, locked) = secure.lock_memory();
225
226 if !locked {
227 return Err(Error::LockFailed);
228 }
229
230 if !encrypted {
231 return Err(Error::CryptProtectMemoryFailed);
232 }
233
234 Ok(secure)
235 }
236
237 pub fn from_slice_mut(slice: &mut [T]) -> Result<Self, Error>
240 where
241 T: Clone + DefaultIsZeroes,
242 {
243 let mut secure_vec = SecureVec::new_with_capacity(slice.len())?;
244 secure_vec.len = slice.len();
245 secure_vec.slice_mut_scope(|dest_slice| {
246 dest_slice.clone_from_slice(slice);
247 });
248 slice.zeroize();
249 Ok(secure_vec)
250 }
251
252 pub fn from_slice(slice: &[T]) -> Result<Self, Error>
255 where
256 T: Clone,
257 {
258 let mut secure_vec = SecureVec::new_with_capacity(slice.len())?;
259 secure_vec.len = slice.len();
260 secure_vec.slice_mut_scope(|dest_slice| {
261 dest_slice.clone_from_slice(slice);
262 });
263 Ok(secure_vec)
264 }
265
266 pub fn len(&self) -> usize {
267 self.len
268 }
269
270 pub fn as_ptr(&self) -> *const T {
271 self.ptr.as_ptr()
272 }
273
274 pub fn as_mut_ptr(&mut self) -> *mut u8 {
275 self.ptr.as_ptr() as *mut u8
276 }
277
278 #[allow(dead_code)]
279 fn allocated_byte_size(&self) -> usize {
280 let size = self.capacity * mem::size_of::<T>();
281 #[cfg(feature = "std")]
282 {
283 (size + page_size() - 1) & !(page_size() - 1)
284 }
285 #[cfg(not(feature = "std"))]
286 {
287 size }
289 }
290
291 #[cfg(all(feature = "std", windows))]
292 fn encypt_memory(&self) -> bool {
293 let ptr = self.as_ptr() as *mut u8;
294 super::crypt_protect_memory(ptr, self.allocated_byte_size())
295 }
296
297 #[cfg(all(feature = "std", windows))]
298 fn decrypt_memory(&self) -> bool {
299 let ptr = self.as_ptr() as *mut u8;
300 super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
301 }
302
303 pub(crate) fn lock_memory(&self) -> (bool, bool) {
309 #[cfg(feature = "std")]
310 {
311 #[cfg(windows)]
312 {
313 let encrypt_ok = self.encypt_memory();
314 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
315 (encrypt_ok, mprotect_ok)
316 }
317 #[cfg(unix)]
318 {
319 let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
320 (true, mprotect_ok)
321 }
322 }
323 #[cfg(not(feature = "std"))]
324 {
325 (true, true) }
327 }
328
329 pub(crate) fn unlock_memory(&self) -> (bool, bool) {
335 #[cfg(feature = "std")]
336 {
337 #[cfg(windows)]
338 {
339 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
340 if !mprotect_ok {
341 return (false, false);
342 }
343 let decrypt_ok = self.decrypt_memory();
344 (decrypt_ok, mprotect_ok)
345 }
346 #[cfg(unix)]
347 {
348 let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
349 (true, mprotect_ok)
350 }
351 }
352
353 #[cfg(not(feature = "std"))]
354 {
355 (true, true) }
357 }
358
359 pub fn unlock_scope<F, R>(&self, f: F) -> R
361 where
362 F: FnOnce(&SecureVec<T>) -> R,
363 {
364 self.unlock_memory();
365 let result = f(self);
366 self.lock_memory();
367 result
368 }
369
370 pub fn slice_scope<F, R>(&self, f: F) -> R
372 where
373 F: FnOnce(&[T]) -> R,
374 {
375 unsafe {
376 self.unlock_memory();
377 let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
378 let result = f(slice);
379 self.lock_memory();
380 result
381 }
382 }
383
384 pub fn slice_mut_scope<F, R>(&mut self, f: F) -> R
386 where
387 F: FnOnce(&mut [T]) -> R,
388 {
389 unsafe {
390 self.unlock_memory();
391 let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
392 let result = f(slice);
393 self.lock_memory();
394 result
395 }
396 }
397
398 pub fn iter_scope<F, R>(&self, f: F) -> R
406 where
407 F: FnOnce(core::slice::Iter<T>) -> R,
408 {
409 unsafe {
410 self.unlock_memory();
411 let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
412 let iter = slice.iter();
413 let result = f(iter);
414 self.lock_memory();
415 result
416 }
417 }
418
419 pub fn iter_mut_scope<F, R>(&mut self, f: F) -> R
427 where
428 F: FnOnce(core::slice::IterMut<T>) -> R,
429 {
430 unsafe {
431 self.unlock_memory();
432 let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
433 let iter = slice.iter_mut();
434 let result = f(iter);
435 self.lock_memory();
436 result
437 }
438 }
439
440 pub fn erase(&mut self) {
444 unsafe {
445 self.unlock_memory();
446 let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
447 for elem in slice.iter_mut() {
448 elem.zeroize();
449 }
450 self.clear();
451 self.lock_memory();
452 }
453 }
454
455 pub fn clear(&mut self) {
459 self.len = 0;
460 }
461
462 pub fn push(&mut self, value: T) {
463 if self.len >= self.capacity {
464 let new_capacity = if self.capacity == 0 {
466 1
467 } else {
468 self.capacity + 1
469 };
470
471 let items_byte_size = new_capacity * mem::size_of::<T>();
472
473 #[cfg(feature = "std")]
475 let new_ptr = unsafe {
476 let aligned_allocation_size = (items_byte_size + page_size() - 1) & !(page_size() - 1);
477 memsec::malloc_sized(aligned_allocation_size)
478 .expect("Failed to allocate memory")
479 .as_ptr() as *mut T
480 };
481
482 #[cfg(not(feature = "std"))]
483 let new_ptr = {
484 let layout = Layout::from_size_align(items_byte_size, mem::align_of::<T>())
485 .expect("Failed to allocate memory");
486 let ptr = unsafe { alloc::alloc(layout) as *mut T };
487 if ptr.is_null() {
488 panic!("Null pointer returned from alloc");
489 }
490 ptr
491 };
492
493 unsafe {
495 self.unlock_memory();
496 core::ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr, self.len);
497 if self.capacity > 0 {
498 let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
499 for elem in slice.iter_mut() {
500 elem.zeroize();
501 }
502 }
503 #[cfg(feature = "std")]
504 memsec::free(self.ptr);
505
506 #[cfg(not(feature = "std"))]
507 {
508 let old_size = self.capacity * mem::size_of::<T>();
509 let old_layout = Layout::from_size_align_unchecked(old_size, mem::align_of::<T>());
510 dealloc(self.ptr.as_ptr() as *mut u8, old_layout);
511 }
512 }
513
514 self.ptr = NonNull::new(new_ptr).expect("Failed to create NonNull");
516 self.capacity = new_capacity;
517
518 unsafe {
520 core::ptr::write(self.ptr.as_ptr().add(self.len), value);
521 self.len += 1;
522 self.lock_memory();
523 }
524 } else {
525 unsafe {
527 self.unlock_memory();
528 core::ptr::write(self.ptr.as_ptr().add(self.len), value);
529 self.len += 1;
530 self.lock_memory();
531 }
532 }
533 }
534
535 pub fn drain<R>(&mut self, range: R) -> Drain<'_, T>
545 where
546 R: RangeBounds<usize>,
547 {
548 self.unlock_memory();
549
550 let original_len = self.len;
551
552 let (drain_start_idx, drain_end_idx) = resolve_range_indices(range, original_len);
553
554 let tail_len = original_len - drain_end_idx;
555
556 self.len = drain_start_idx;
557
558 Drain {
559 vec_ref: self,
560 drain_start_index: drain_start_idx,
561 current_drain_iter_index: drain_start_idx,
562 drain_end_index: drain_end_idx,
563 original_vec_len: original_len,
564 tail_len,
565 _marker: PhantomData,
566 }
567 }
568}
569
570impl<T: Clone + Zeroize> Clone for SecureVec<T> {
571 fn clone(&self) -> Self {
572 let mut new_vec: SecureVec<T> = SecureVec::new_with_capacity(self.capacity).unwrap();
573
574 new_vec.unlock_memory();
575 self.unlock_memory();
576
577 unsafe {
578 for i in 0..self.len {
579 let value = (*self.ptr.as_ptr().add(i)).clone();
580 core::ptr::write(new_vec.ptr.as_ptr().add(i), value);
581 }
582 }
583
584 new_vec.len = self.len;
585 self.lock_memory();
586 new_vec.lock_memory();
587 new_vec
588 }
589}
590
591impl<T: Zeroize> Drop for SecureVec<T> {
592 fn drop(&mut self) {
593 self.erase();
594 self.unlock_memory();
595 unsafe {
596 #[cfg(feature = "std")]
597 memsec::free(self.ptr);
598
599 #[cfg(not(feature = "std"))]
600 {
601 let layout =
603 Layout::from_size_align_unchecked(self.allocated_byte_size(), mem::align_of::<T>());
604 dealloc(self.ptr.as_ptr() as *mut u8, layout);
605 }
606 }
607 }
608}
609
610impl<T: Zeroize> core::ops::Index<usize> for SecureVec<T> {
611 type Output = T;
612
613 fn index(&self, index: usize) -> &Self::Output {
614 assert!(index < self.len, "Index out of bounds");
615 unsafe {
616 let ptr = self.ptr.as_ptr().add(index);
617 let reference = &*ptr;
618 reference
619 }
620 }
621}
622
#[cfg(feature = "serde")]
impl serde::Serialize for SecureVec<u8> {
    /// Serializes the bytes as a sequence. The buffer is unlocked only for the
    /// duration of the call.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.slice_scope(move |bytes| serializer.collect_seq(bytes))
    }
}
633
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureVec<u8> {
    /// Deserializes a sequence (or byte string) of `u8` into secure memory.
    ///
    /// Incoming bytes are staged in an ordinary `Vec`. That staging buffer is
    /// always zeroized: when it grows, when deserialization fails, and when its
    /// contents move into the `SecureVec`.
    fn deserialize<D>(deserializer: D) -> Result<SecureVec<u8>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct SecureVecVisitor;

        impl<'de> serde::de::Visitor<'de> for SecureVecVisitor {
            type Value = SecureVec<u8>;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                write!(formatter, "a sequence of bytes")
            }

            fn visit_seq<A>(
                self,
                mut seq: A,
            ) -> Result<<Self as serde::de::Visitor<'de>>::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // The hint comes from the input, so cap it rather than trust it.
                let mut staging: Vec<u8> =
                    Vec::with_capacity(seq.size_hint().unwrap_or(0).min(4096));
                loop {
                    match seq.next_element::<u8>() {
                        Ok(Some(byte)) => {
                            if staging.len() == staging.capacity() {
                                // Grow by hand: `Vec::push` would reallocate and
                                // free the old buffer without wiping it.
                                let mut grown =
                                    Vec::with_capacity(staging.capacity().saturating_mul(2).max(16));
                                grown.extend_from_slice(&staging);
                                staging.zeroize();
                                staging = grown;
                            }
                            staging.push(byte);
                        }
                        Ok(None) => break,
                        Err(err) => {
                            staging.zeroize();
                            return Err(err);
                        }
                    }
                }
                SecureVec::from_vec(staging).map_err(serde::de::Error::custom)
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                SecureVec::from_slice(v).map_err(E::custom)
            }

            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // `from_vec` wipes `v` after moving its bytes.
                SecureVec::from_vec(v).map_err(E::custom)
            }
        }

        deserializer.deserialize_seq(SecureVecVisitor)
    }
}
663
/// Iterator returned by [`SecureVec::drain`]; yields the removed elements by value.
///
/// While a `Drain` is alive, the vector's buffer stays unlocked: `drain` calls
/// `unlock_memory`, and `Drop` for `Drain` calls `lock_memory`. Leaking a
/// `Drain` therefore leaves the buffer unlocked.
pub struct Drain<'a, T: Zeroize + 'a> {
    // The vector being drained. Its `len` was already cut to `drain_start_index`.
    vec_ref: &'a mut SecureVec<T>,
    // First index of the drained range.
    drain_start_index: usize,
    // Index of the next element `next()` will yield.
    current_drain_iter_index: usize,
    // One past the last index of the drained range.
    drain_end_index: usize,

    // Vector length before draining started.
    original_vec_len: usize,
    // Number of elements after the drained range that must be shifted down.
    tail_len: usize,
    _marker: PhantomData<&'a T>,
}
687
688impl<'a, T: Zeroize> Iterator for Drain<'a, T> {
689 type Item = T;
690
691 fn next(&mut self) -> Option<T> {
692 if self.current_drain_iter_index < self.drain_end_index {
693 unsafe {
695 let item_ptr = self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index);
696 let item = ptr::read(item_ptr);
697 self.current_drain_iter_index += 1;
698 Some(item)
699 }
700 } else {
701 None
702 }
703 }
704
705 fn size_hint(&self) -> (usize, Option<usize>) {
706 let remaining = self.drain_end_index - self.current_drain_iter_index;
707 (remaining, Some(remaining))
708 }
709}
710
711impl<'a, T: Zeroize> ExactSizeIterator for Drain<'a, T> {}
712
713impl<'a, T: Zeroize> Drop for Drain<'a, T> {
714 fn drop(&mut self) {
715 unsafe {
716 if mem::needs_drop::<T>() {
718 let mut current_ptr =
719 self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index) as *mut T;
720 let end_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index) as *mut T;
721 while current_ptr < end_ptr {
722 ptr::drop_in_place(current_ptr);
723 current_ptr = current_ptr.add(1);
724 }
725 }
726
727 let hole_dst_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_start_index) as *mut T;
728 let tail_src_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index) as *mut T;
729
730 if self.tail_len > 0 {
731 ptr::copy(tail_src_ptr, hole_dst_ptr, self.tail_len);
732 }
733
734 let new_len = self.drain_start_index + self.tail_len;
736
737 let mut current_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(new_len) as *mut T;
747 let end_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(self.original_vec_len) as *mut T;
748
749 let original_tail_start_ptr_val = tail_src_ptr as usize;
751
752 while current_cleanup_ptr < end_cleanup_ptr {
753 if mem::needs_drop::<T>() {
754 let current_ptr_val = current_cleanup_ptr as usize;
755 let original_tail_end_ptr_val =
756 original_tail_start_ptr_val + self.tail_len * mem::size_of::<T>();
757
758 if current_ptr_val >= original_tail_start_ptr_val
759 && current_ptr_val < original_tail_end_ptr_val
760 {
761 ptr::drop_in_place(current_cleanup_ptr);
764 }
765 }
768
769 (*current_cleanup_ptr).zeroize();
771 current_cleanup_ptr = current_cleanup_ptr.add(1);
772 }
773
774 self.vec_ref.len = new_len;
776
777 self.vec_ref.lock_memory();
779 }
780 }
781}
782
/// Converts any `RangeBounds<usize>` into a half-open `(start, end)` pair that
/// is valid for a sequence of length `len`.
///
/// # Panics
/// - if an `Excluded` start or `Included` end is `usize::MAX`;
/// - if `start > end`;
/// - if `end > len`.
fn resolve_range_indices<R: RangeBounds<usize>>(range: R, len: usize) -> (usize, usize) {
    let start = match range.start_bound() {
        Bound::Unbounded => 0,
        Bound::Included(&first) => first,
        Bound::Excluded(&before) => match before.checked_add(1) {
            Some(first) => first,
            None => panic!("attempted to start drain at Excluded(usize::MAX)"),
        },
    };

    let end = match range.end_bound() {
        Bound::Unbounded => len,
        Bound::Excluded(&past) => past,
        Bound::Included(&last) => match last.checked_add(1) {
            Some(past) => past,
            None => panic!("attempted to end drain at Included(usize::MAX)"),
        },
    };

    assert!(
        start <= end,
        "drain range start ({}) must be less than or equal to end ({})",
        start,
        end
    );
    assert!(
        end <= len,
        "drain range end ({}) out of bounds for slice of length {}",
        end,
        len
    );

    (start, end)
}
819
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::process::{Command, Stdio};
    use std::sync::{Arc, Mutex};

    /// Builds a locked vector holding `[1, 2, 3]`.
    fn one_two_three() -> SecureVec<u8> {
        SecureVec::from_vec(vec![1u8, 2, 3]).unwrap()
    }

    /// Builds a locked vector holding `1..=10`.
    fn one_to_ten() -> SecureVec<u8> {
        SecureVec::from_vec((1..=10).collect::<Vec<u8>>()).unwrap()
    }

    /// Checks that `secure` has a non-empty allocation and that one
    /// unlock/lock cycle reports success.
    fn assert_lock_cycle(secure: &SecureVec<u8>) {
        assert!(secure.allocated_byte_size() > 0);

        let (decrypted, unlocked) = secure.unlock_memory();
        assert!(decrypted);
        assert!(unlocked);

        let (encrypted, locked) = secure.lock_memory();
        assert!(encrypted);
        assert!(locked);
    }

    #[test]
    fn test_from_methods() {
        let from_vec = one_two_three();
        from_vec.slice_scope(|slice| assert_eq!(slice, &[1, 2, 3]));

        let mut source = [3u8, 5];
        let from_mut = SecureVec::from_slice_mut(&mut source).unwrap();
        assert_eq!(source, [0, 0]);
        from_mut.slice_scope(|slice| assert_eq!(slice, &[3, 5]));

        let borrowed = [3u8, 5];
        let from_ref = SecureVec::from_slice(&borrowed).unwrap();
        from_ref.slice_scope(|slice| assert_eq!(slice, &[3, 5]));
    }

    #[test]
    fn lock_unlock() {
        let fresh: SecureVec<u8> = SecureVec::new().unwrap();
        assert_lock_cycle(&fresh);

        let from_empty_vec: SecureVec<u8> = SecureVec::from_vec(vec![]).unwrap();
        assert_lock_cycle(&from_empty_vec);

        let zero_capacity: SecureVec<u8> = SecureVec::new_with_capacity(0).unwrap();
        assert_lock_cycle(&zero_capacity);
    }

    #[test]
    fn test_thread_safety() {
        let shared = Arc::new(Mutex::new(SecureVec::from_vec(Vec::<u8>::new()).unwrap()));

        let workers: Vec<_> = (0..5u8)
            .map(|value| {
                let shared = Arc::clone(&shared);
                std::thread::spawn(move || {
                    shared.lock().unwrap().push(value);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let guard = shared.lock().unwrap();
        guard.slice_scope(|slice| assert_eq!(slice.len(), 5));
    }

    #[test]
    fn test_clone() {
        let original = one_two_three();
        let _copy = original.clone();
    }

    #[test]
    fn test_do_not_call_forget_on_drain() {
        let mut secure = one_to_ten();
        core::mem::forget(secure.drain(..3));
        secure.slice_scope(|slice| assert_eq!(slice.len(), 0));
    }

    #[test]
    fn test_drain() {
        let mut secure = one_to_ten();
        {
            let mut drain = secure.drain(..3);
            for expected in 1..=3u8 {
                assert_eq!(drain.next(), Some(expected));
            }
            assert_eq!(drain.next(), None);
        }
        secure.slice_scope(|slice| {
            assert_eq!(slice.len(), 7);
            assert_eq!(slice, &[4, 5, 6, 7, 8, 9, 10]);
        });
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_secure_vec_serde() {
        let secure = one_two_three();
        let json = serde_json::to_vec(&secure).expect("Serialization failed");
        let restored: SecureVec<u8> =
            serde_json::from_slice(&json).expect("Deserialization failed");
        restored.slice_scope(|slice| assert_eq!(slice, &[1, 2, 3]));
    }

    #[test]
    fn test_erase() {
        let mut secure = one_two_three();
        secure.erase();
        secure.unlock_scope(|s| {
            assert_eq!(s.len, 0);
            assert_eq!(s.capacity, 3);
        });

        for value in 1..=3u8 {
            secure.push(value);
        }
        secure.unlock_scope(|s| {
            assert_eq!([s[0], s[1], s[2]], [1, 2, 3]);
            assert_eq!(s.len, 3);
            assert_eq!(s.capacity, 3);
        });
    }

    #[test]
    fn test_push() {
        let mut secure = SecureVec::from_vec(Vec::<u8>::new()).unwrap();
        (0..30u8).for_each(|value| secure.push(value));
    }

    #[test]
    fn test_index() {
        let secure = one_two_three();
        secure.unlock_scope(|s| assert_eq!([s[0], s[1], s[2]], [1, 2, 3]));
    }

    #[test]
    fn test_slice_scoped() {
        let secure = one_two_three();
        secure.slice_scope(|slice| assert_eq!(slice, &[1, 2, 3]));
    }

    #[test]
    fn test_slice_mut_scoped() {
        let mut secure = one_two_three();

        secure.slice_mut_scope(|slice| {
            slice[0] = 4;
            assert_eq!(slice, &mut [4, 2, 3]);
        });

        secure.slice_scope(|slice| assert_eq!(slice, &[4, 2, 3]));
    }

    #[test]
    fn test_iter_scoped() {
        let filled = one_two_three();
        let total: u8 = filled.iter_scope(|iter| iter.copied().sum());
        assert_eq!(total, 6);

        let empty: SecureVec<u8> = SecureVec::new_with_capacity(3).unwrap();
        let total: u8 = empty.iter_scope(|iter| iter.copied().sum());
        assert_eq!(total, 0);
    }

    #[test]
    fn test_iter_mut_scoped() {
        let mut secure = one_two_three();
        secure.iter_mut_scope(|iter| iter.for_each(|elem| *elem += 1));
        secure.slice_scope(|slice| assert_eq!(slice, &[2, 3, 4]));
    }

    /// Re-runs this test in a child process. The child reads a locked element
    /// without unlocking it, and the parent checks that the child died from a
    /// memory-access fault.
    #[test]
    fn test_index_should_fail_when_locked() {
        let arg = "CRASH_TEST_SECUREVEC_LOCKED";

        if std::env::args().any(|a| a == arg) {
            // Child: this read must fault because the buffer is locked.
            let secure = one_two_three();
            let _value = core::hint::black_box(secure[0]);

            std::process::exit(1);
        }

        let child = Command::new(std::env::current_exe().unwrap())
            .arg("vec::tests::test_index_should_fail_when_locked")
            .arg(arg)
            .arg("--nocapture")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("Failed to spawn child process");

        let status = child
            .wait_with_output()
            .expect("Failed to wait on child")
            .status;

        assert!(
            !status.success(),
            "Process exited successfully with code {:?}, but it should have crashed.",
            status.code()
        );

        #[cfg(unix)]
        {
            use std::os::unix::process::ExitStatusExt;
            let signal = status
                .signal()
                .expect("Process was not terminated by a signal on Unix.");
            assert!(
                signal == libc::SIGSEGV || signal == libc::SIGBUS,
                "Process terminated with unexpected signal: {}",
                signal
            );
            println!(
                "Test passed: Process correctly terminated with signal {}.",
                signal
            );
        }

        #[cfg(windows)]
        {
            const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
            assert_eq!(
                status.code(),
                Some(STATUS_ACCESS_VIOLATION),
                "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
                status.code()
            );
            eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
        }
    }
}