secure_types/
vec.rs

1#[cfg(not(feature = "use_os"))]
2use alloc::{Layout, alloc, dealloc};
3
4#[cfg(feature = "use_os")]
5use std::vec::Vec;
6
7use super::{Error, SecureArray};
8use core::{
9   marker::PhantomData,
10   mem,
11   ops::{Bound, RangeBounds},
12   ptr::{self, NonNull},
13};
14use zeroize::{DefaultIsZeroes, Zeroize};
15
16#[cfg(feature = "use_os")]
17use super::{alloc_aligned, page_aligned_size};
18#[cfg(feature = "use_os")]
19use memsec::Prot;
20
/// Byte-oriented alias for [`SecureVec<u8>`].
pub type SecureBytes = SecureVec<u8>;

/// A securely allocated, growable vector, similar to `std::vec::Vec`.
///
/// ## Security Model
///
/// When compiled with the `use_os` feature (the default), it provides several layers of protection:
/// - **Zeroization on Drop**: The memory is zeroized when the vector is dropped.
/// - **Memory Locking**: The underlying memory pages are locked using `mlock` & `madvise` (Unix) or
///   `VirtualLock` & `VirtualProtect` (Windows), to keep them out of swap and memory dumps and to
///   keep other processes from accessing the memory.
/// - **Memory Encryption**: On Windows, the memory is also encrypted using `CryptProtectMemory`.
///
/// In a `no_std` environment (without `use_os`), only the **zeroization-on-drop** guarantee applies.
///
/// ## Program Termination
///
/// Direct indexing (e.g., `vec[0]`) on a locked vector will cause the operating system
/// to terminate the process with an access violation error.
///
/// Always use the provided scope methods (`unlock`, `unlock_slice`, `unlock_slice_mut`, ...)
/// for safe access.
///
/// # Notes
///
/// If you copy the data out of one of the unlock methods (for example into a new `Vec`),
/// you are responsible for zeroizing that copy.
///
/// # Example
///
/// Using `SecureBytes` (a type alias for `SecureVec<u8>`) to handle a secret key.
///
/// ```
/// use secure_types::{SecureBytes, Zeroize};
///
/// // Create a new, empty secure vector.
/// let mut secret_key = SecureBytes::new().unwrap();
///
/// // Push some sensitive data into it.
/// secret_key.push(0xAB);
/// secret_key.push(0xCD);
/// secret_key.push(0xEF);
///
/// // The memory is locked here.
///
/// // Use a scope to safely access the contents as a slice.
/// secret_key.unlock_slice(|unlocked_slice| {
///     assert_eq!(unlocked_slice, &[0xAB, 0xCD, 0xEF]);
/// });
///
/// // Not recommended, but if you copy the data into a new Vec make sure to zeroize it.
/// let mut exposed = secret_key.unlock_slice(|unlocked_slice| {
///     Vec::from(unlocked_slice)
/// });
///
/// // Do what you need to do with the new vector.
/// // When you are done with it, zeroize it.
/// exposed.zeroize();
///
/// // Each unlock scope re-locks the memory when it ends.
///
/// // When `secret_key` is dropped, its memory is securely zeroized.
/// ```
pub struct SecureVec<T>
where
   T: Zeroize,
{
   // Start of the secure allocation (room for `capacity` elements).
   ptr: NonNull<T>,
   // Number of initialized elements at the start of the buffer.
   pub(crate) len: usize,
   // Number of elements the allocation can hold.
   pub(crate) capacity: usize,
   _marker: PhantomData<T>,
}

// SAFETY: each `SecureVec` allocates its own buffer behind `ptr` and frees it on
// drop (it owns it like `Vec` owns its heap buffer), so moving it to another
// thread is sound when `T: Send`.
unsafe impl<T: Zeroize + Send> Send for SecureVec<T> {}
// SAFETY: shared references only expose `&T`, which is sound when `T: Sync`.
// NOTE(review): `unlock`/`unlock_slice`/`unlock_iter` take `&self` but change the
// page protection (and on Windows the encryption state) of the shared buffer.
// Two threads doing this concurrently can re-lock the memory while the other is
// still reading it (crash) or decrypt it twice; share a `SecureVec` across
// threads behind a lock such as `Mutex`.
unsafe impl<T: Zeroize + Send + Sync> Sync for SecureVec<T> {}
93
94impl<T: Zeroize> SecureVec<T> {
95   /// Create a new `SecureVec` with a capacity of 1
96   pub fn new() -> Result<Self, Error> {
97      // Give at least a capacity of 1 so encryption/decryption can be done.
98      let capacity = 1;
99      let size = capacity * mem::size_of::<T>();
100      let ptr = alloc_aligned::<T>(size)?;
101
102      let secure = SecureVec {
103         ptr,
104         len: 0,
105         capacity,
106         _marker: PhantomData,
107      };
108
109      let (encrypted, locked) = secure.lock_memory();
110
111      #[cfg(feature = "use_os")]
112      if !locked {
113         return Err(Error::LockFailed);
114      }
115
116      #[cfg(feature = "use_os")]
117      if !encrypted {
118         return Err(Error::CryptProtectMemoryFailed);
119      }
120
121      Ok(secure)
122   }
123
124   /// Create a new `SecureVec` with the given capacity
125   pub fn new_with_capacity(mut capacity: usize) -> Result<Self, Error> {
126      if capacity == 0 {
127         capacity = 1;
128      }
129
130      let size = capacity * mem::size_of::<T>();
131      let ptr = alloc_aligned::<T>(size)?;
132
133      let secure = SecureVec {
134         ptr,
135         len: 0,
136         capacity,
137         _marker: PhantomData,
138      };
139
140      let (encrypted, locked) = secure.lock_memory();
141
142      #[cfg(feature = "use_os")]
143      if !locked {
144         return Err(Error::LockFailed);
145      }
146
147      #[cfg(feature = "use_os")]
148      if !encrypted {
149         return Err(Error::CryptProtectMemoryFailed);
150      }
151
152      Ok(secure)
153   }
154
155   #[cfg(feature = "use_os")]
156   /// Create a new `SecureVec` from a `Vec`
157   ///
158   /// The `Vec` is zeroized afterwards
159   pub fn from_vec(mut vec: Vec<T>) -> Result<Self, Error> {
160      if vec.capacity() == 0 {
161         vec.reserve(1);
162      }
163
164      let capacity = vec.capacity();
165      let len = vec.len();
166
167      let size = capacity * mem::size_of::<T>();
168
169      let ptr = match alloc_aligned::<T>(size) {
170         Ok(ptr) => ptr,
171         Err(_) => {
172            vec.zeroize();
173            return Err(Error::AllocationFailed);
174         }
175      };
176
177      // Copy data from the old pointer to the new one
178      unsafe {
179         core::ptr::copy_nonoverlapping(vec.as_ptr(), ptr.as_ptr() as *mut T, len);
180      }
181
182      vec.zeroize();
183
184      let secure = SecureVec {
185         ptr,
186         len,
187         capacity,
188         _marker: PhantomData,
189      };
190
191      let (encrypted, locked) = secure.lock_memory();
192
193      if !locked {
194         return Err(Error::LockFailed);
195      }
196
197      if !encrypted {
198         return Err(Error::CryptProtectMemoryFailed);
199      }
200
201      Ok(secure)
202   }
203
204   /// Create a new `SecureVec` from a mutable slice.
205   ///
206   /// The slice is zeroized afterwards
207   pub fn from_slice_mut(slice: &mut [T]) -> Result<Self, Error>
208   where
209      T: Clone + DefaultIsZeroes,
210   {
211      let mut secure_vec = SecureVec::new_with_capacity(slice.len())?;
212      secure_vec.len = slice.len();
213      secure_vec.unlock_slice_mut(|dest_slice| {
214         dest_slice.clone_from_slice(slice);
215      });
216      slice.zeroize();
217      Ok(secure_vec)
218   }
219
220   /// Create a new `SecureVec` from a slice.
221   ///
222   /// The slice is not zeroized, you are responsible for zeroizing it
223   pub fn from_slice(slice: &[T]) -> Result<Self, Error>
224   where
225      T: Clone,
226   {
227      let mut secure_vec = SecureVec::new_with_capacity(slice.len())?;
228      secure_vec.len = slice.len();
229      secure_vec.unlock_slice_mut(|dest_slice| {
230         dest_slice.clone_from_slice(slice);
231      });
232      Ok(secure_vec)
233   }
234
235   pub fn len(&self) -> usize {
236      self.len
237   }
238
239   pub fn is_empty(&self) -> bool {
240      self.len() == 0
241   }
242
243   pub fn as_ptr(&self) -> *const T {
244      self.ptr.as_ptr()
245   }
246
247   pub fn as_mut_ptr(&mut self) -> *mut u8 {
248      self.ptr.as_ptr() as *mut u8
249   }
250
251   #[allow(dead_code)]
252   fn aligned_size(&self) -> usize {
253      let size = self.capacity * mem::size_of::<T>();
254      #[cfg(feature = "use_os")]
255      {
256         unsafe { page_aligned_size(size) }
257      }
258      #[cfg(not(feature = "use_os"))]
259      {
260         size // No page alignment in no_std
261      }
262   }
263
264   #[cfg(all(feature = "use_os", windows))]
265   fn encypt_memory(&self) -> bool {
266      let ptr = self.as_ptr() as *mut u8;
267      super::crypt_protect_memory(ptr, self.aligned_size())
268   }
269
270   #[cfg(all(feature = "use_os", windows))]
271   fn decrypt_memory(&self) -> bool {
272      let ptr = self.as_ptr() as *mut u8;
273      super::crypt_unprotect_memory(ptr, self.aligned_size())
274   }
275
276   pub(crate) fn lock_memory(&self) -> (bool, bool) {
277      #[cfg(feature = "use_os")]
278      {
279         #[cfg(windows)]
280         {
281            let encrypt_ok = self.encypt_memory();
282            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
283            (encrypt_ok, mprotect_ok)
284         }
285         #[cfg(unix)]
286         {
287            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
288            (true, mprotect_ok)
289         }
290      }
291      #[cfg(not(feature = "use_os"))]
292      {
293         (true, true) // No-op: always "succeeds"
294      }
295   }
296
297   pub(crate) fn unlock_memory(&self) -> (bool, bool) {
298      #[cfg(feature = "use_os")]
299      {
300         #[cfg(windows)]
301         {
302            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
303            if !mprotect_ok {
304               return (false, false);
305            }
306            let decrypt_ok = self.decrypt_memory();
307            (decrypt_ok, mprotect_ok)
308         }
309         #[cfg(unix)]
310         {
311            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
312            (true, mprotect_ok)
313         }
314      }
315
316      #[cfg(not(feature = "use_os"))]
317      {
318         (true, true) // No-op: always "succeeds"
319      }
320   }
321
322   /// Immutable access to the `SecureVec`
323   pub fn unlock<F, R>(&self, f: F) -> R
324   where
325      F: FnOnce(&SecureVec<T>) -> R,
326   {
327      self.unlock_memory();
328      let result = f(self);
329      self.lock_memory();
330      result
331   }
332
333   /// Immutable access to the `SecureVec` as `&[T]`
334   pub fn unlock_slice<F, R>(&self, f: F) -> R
335   where
336      F: FnOnce(&[T]) -> R,
337   {
338      unsafe {
339         self.unlock_memory();
340         let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
341         let result = f(slice);
342         self.lock_memory();
343         result
344      }
345   }
346
347   /// Mutable access to the `SecureVec` as `&mut [T]`
348   pub fn unlock_slice_mut<F, R>(&mut self, f: F) -> R
349   where
350      F: FnOnce(&mut [T]) -> R,
351   {
352      unsafe {
353         self.unlock_memory();
354         let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
355         let result = f(slice);
356         self.lock_memory();
357         result
358      }
359   }
360
361   /// Immutable access to the `SecureVec` as `Iter<T>`
362   pub fn unlock_iter<F, R>(&self, f: F) -> R
363   where
364      F: FnOnce(core::slice::Iter<T>) -> R,
365   {
366      unsafe {
367         self.unlock_memory();
368         let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
369         let iter = slice.iter();
370         let result = f(iter);
371         self.lock_memory();
372         result
373      }
374   }
375
376   /// Mutable access to the `SecureVec` as `IterMut<T>`
377   pub fn unlock_iter_mut<F, R>(&mut self, f: F) -> R
378   where
379      F: FnOnce(core::slice::IterMut<T>) -> R,
380   {
381      unsafe {
382         self.unlock_memory();
383         let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
384         let iter = slice.iter_mut();
385         let result = f(iter);
386         self.lock_memory();
387         result
388      }
389   }
390
391   /// Erase the underlying data and clears the vector
392   ///
393   /// The memory is locked again and the capacity is preserved for reuse
394   pub fn erase(&mut self) {
395      unsafe {
396         self.unlock_memory();
397         let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
398         for elem in slice.iter_mut() {
399            elem.zeroize();
400         }
401         self.clear();
402         self.lock_memory();
403      }
404   }
405
406   /// Clear the vector
407   ///
408   /// This just sets the vector's len to zero it does not erase the underlying data
409   pub fn clear(&mut self) {
410      self.len = 0;
411   }
412
413   pub fn push(&mut self, value: T) {
414      self.reserve(1);
415
416      self.unlock_memory();
417
418      unsafe {
419         // Write the new value at the end of the vector.
420         core::ptr::write(self.ptr.as_ptr().add(self.len), value);
421
422         self.len += 1;
423      }
424
425      self.lock_memory();
426   }
427
428   /// Ensures that the vector has enough capacity for at least `additional` more elements.
429   ///
430   /// If more capacity is needed, it will reallocate. This may cause the buffer location to change.
431   ///
432   /// # Panics
433   ///
434   /// Panics if the new capacity overflows `usize` or if the allocation fails.
435   pub fn reserve(&mut self, additional: usize) {
436      if self.len() + additional <= self.capacity {
437         return;
438      }
439
440      // Use an amortized growth strategy to avoid reallocating on every push
441      let required_capacity = self.len() + additional;
442      let new_capacity = (self.capacity.max(1) * 2).max(required_capacity);
443
444      let new_size = new_capacity * mem::size_of::<T>();
445      let new_ptr = alloc_aligned::<T>(new_size).expect("Allocation failed");
446
447      // Copy data to new pointer
448      unsafe {
449         self.unlock_memory();
450         core::ptr::copy_nonoverlapping(
451            self.ptr.as_ptr(),
452            new_ptr.as_ptr() as *mut T,
453            self.len(),
454         );
455
456         // Erase and free the old memory
457         if self.capacity > 0 {
458            let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
459            for elem in slice.iter_mut() {
460               elem.zeroize();
461            }
462         }
463
464         #[cfg(feature = "use_os")]
465         memsec::free(self.ptr);
466
467         #[cfg(not(feature = "use_os"))]
468         {
469            let old_size = self.capacity * mem::size_of::<T>();
470            let old_layout = Layout::from_size_align_unchecked(old_size, mem::align_of::<T>());
471            dealloc(self.ptr.as_ptr() as *mut u8, old_layout);
472         }
473      }
474
475      // Update pointer and capacity, then re-lock the new memory region
476      self.ptr = new_ptr;
477      self.capacity = new_capacity;
478      self.lock_memory();
479   }
480
481   /// Creates a draining iterator that removes the specified range from the vector
482   /// and yields the removed items.
483   ///
484   /// Note: The vector is unlocked during the lifetime of the `Drain` iterator.
485   /// The memory is relocked when the `Drain` iterator is dropped.
486   ///
487   /// # Panics
488   /// Panics if the starting point is greater than the end point or if the end point
489   /// is greater than the length of the vector.
490   pub fn drain<R>(&mut self, range: R) -> Drain<'_, T>
491   where
492      R: RangeBounds<usize>,
493   {
494      self.unlock_memory();
495
496      let original_len = self.len;
497
498      let (drain_start_idx, drain_end_idx) = resolve_range_indices(range, original_len);
499
500      let tail_len = original_len - drain_end_idx;
501
502      self.len = drain_start_idx;
503
504      Drain {
505         vec_ref: self,
506         drain_start_index: drain_start_idx,
507         current_drain_iter_index: drain_start_idx,
508         drain_end_index: drain_end_idx,
509         original_vec_len: original_len,
510         tail_len,
511         _marker: PhantomData,
512      }
513   }
514}
515
516impl<T: Clone + Zeroize> Clone for SecureVec<T> {
517   fn clone(&self) -> Self {
518      let mut new_vec = SecureVec::new_with_capacity(self.capacity).unwrap();
519      new_vec.len = self.len;
520
521      self.unlock_slice(|src_slice| {
522         new_vec.unlock_slice_mut(|dest_slice| {
523            dest_slice.clone_from_slice(src_slice);
524         });
525      });
526
527      new_vec
528   }
529}
530
531impl<const LENGTH: usize> From<SecureArray<u8, LENGTH>> for SecureVec<u8> {
532   fn from(array: SecureArray<u8, LENGTH>) -> Self {
533      let mut new_vec = SecureVec::new_with_capacity(LENGTH)
534         .expect("Failed to allocate SecureVec during conversion");
535      new_vec.len = array.len();
536
537      array.unlock(|array_slice| {
538         new_vec.unlock_slice_mut(|vec_slice| {
539            vec_slice.copy_from_slice(array_slice);
540         });
541      });
542
543      new_vec
544   }
545}
546
547impl<T: Zeroize> Drop for SecureVec<T> {
548   fn drop(&mut self) {
549      self.erase();
550      self.unlock_memory();
551      unsafe {
552         #[cfg(feature = "use_os")]
553         memsec::free(self.ptr);
554
555         #[cfg(not(feature = "use_os"))]
556         {
557            let layout =
558               Layout::from_size_align_unchecked(self.allocated_byte_size(), mem::align_of::<T>());
559            dealloc(self.ptr.as_ptr() as *mut u8, layout);
560         }
561      }
562   }
563}
564
565impl<T: Zeroize> core::ops::Index<usize> for SecureVec<T> {
566   type Output = T;
567
568   fn index(&self, index: usize) -> &Self::Output {
569      assert!(index < self.len, "Index out of bounds");
570      unsafe {
571         let ptr = self.ptr.as_ptr().add(index);
572         &*ptr
573      }
574   }
575}
576
#[cfg(feature = "serde")]
impl serde::Serialize for SecureVec<u8> {
   /// Serializes the bytes as a sequence, reading them while the memory is
   /// briefly unlocked.
   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
   where
      S: serde::Serializer,
   {
      self.unlock_slice(move |bytes| serializer.collect_seq(bytes))
   }
}
586
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureVec<u8> {
   /// Deserializes a sequence of bytes directly into secure memory.
   ///
   /// The bytes are pushed straight into a `SecureVec` instead of being collected
   /// in a plain `Vec` first. A growing `Vec` leaves unwiped copies of the data in
   /// every buffer it outgrows, and an error from the deserializer would drop it
   /// without zeroizing. A `SecureVec` zeroizes its old elements when it grows and
   /// zeroizes itself when dropped.
   ///
   /// # Panics
   ///
   /// Panics if growing the vector fails (see `SecureVec::reserve`).
   fn deserialize<D>(deserializer: D) -> Result<SecureVec<u8>, D::Error>
   where
      D: serde::Deserializer<'de>,
   {
      // Cap on how much is preallocated from the (untrusted) length hint.
      const MAX_PREALLOCATION: usize = 4096;

      struct SecureVecVisitor;
      impl<'de> serde::de::Visitor<'de> for SecureVecVisitor {
         type Value = SecureVec<u8>;
         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            write!(formatter, "a sequence of bytes")
         }
         fn visit_seq<A>(
            self,
            mut seq: A,
         ) -> Result<<Self as serde::de::Visitor<'de>>::Value, A::Error>
         where
            A: serde::de::SeqAccess<'de>,
         {
            let capacity = seq.size_hint().unwrap_or(0).min(MAX_PREALLOCATION);
            let mut vec: SecureVec<u8> =
               SecureVec::new_with_capacity(capacity).map_err(serde::de::Error::custom)?;
            // On an early `?` return `vec` is dropped, which zeroizes it.
            while let Some(byte) = seq.next_element::<u8>()? {
               vec.push(byte);
            }
            Ok(vec)
         }
      }
      deserializer.deserialize_seq(SecureVecVisitor)
   }
}
616
/// A draining iterator for `SecureVec<T>`.
///
/// This struct is created by the `drain` method on `SecureVec`. The vector's memory
/// stays unlocked while the `Drain` is alive. Dropping the `Drain` removes the drained
/// range, shifts the remaining tail down and locks the memory again.
///
/// # Safety
/// The returned `Drain` iterator must not be forgotten (via `mem::forget`).
/// Forgetting the iterator leaves the `SecureVec`'s len at the start of the drained range.
/// Everything from there on, including the tail, is lost to the vector, and the memory
/// remains unlocked.
pub struct Drain<'a, T: Zeroize + 'a> {
   vec_ref: &'a mut SecureVec<T>,
   // First index of the drained range.
   drain_start_index: usize,
   // Index of the next element `next()` will yield.
   current_drain_iter_index: usize,
   // One past the last index of the drained range.
   drain_end_index: usize,

   original_vec_len: usize, // Original length of vec_ref before drain
   tail_len: usize,         // Number of elements after the drain range in the original vec

   _marker: PhantomData<&'a T>,
}
635
636impl<'a, T: Zeroize> Iterator for Drain<'a, T> {
637   type Item = T;
638
639   fn next(&mut self) -> Option<T> {
640      if self.current_drain_iter_index < self.drain_end_index {
641         // SecureVec is already unlocked by the `drain` method.
642         unsafe {
643            let item_ptr = self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index);
644            let item = ptr::read(item_ptr);
645            self.current_drain_iter_index += 1;
646            Some(item)
647         }
648      } else {
649         None
650      }
651   }
652
653   fn size_hint(&self) -> (usize, Option<usize>) {
654      let remaining = self.drain_end_index - self.current_drain_iter_index;
655      (remaining, Some(remaining))
656   }
657}
658
659impl<'a, T: Zeroize> ExactSizeIterator for Drain<'a, T> {}
660
661impl<'a, T: Zeroize> Drop for Drain<'a, T> {
662   fn drop(&mut self) {
663      unsafe {
664         // The vec_ref's memory is currently unlocked.
665         if mem::needs_drop::<T>() {
666            let mut current_ptr = self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index);
667            let end_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index);
668            while current_ptr < end_ptr {
669               ptr::drop_in_place(current_ptr);
670               current_ptr = current_ptr.add(1);
671            }
672         }
673
674         let hole_dst_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_start_index);
675         let tail_src_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index);
676
677         if self.tail_len > 0 {
678            ptr::copy(tail_src_ptr, hole_dst_ptr, self.tail_len);
679         }
680
681         // The new length of the vector.
682         let new_len = self.drain_start_index + self.tail_len;
683
684         // Process the memory region that is no longer part of the active vector's content.
685         // This region is from `vec_ref.ptr + new_len` up to `vec_ref.ptr + original_vec_len`.
686         // It contains:
687         //    a) Original data of the latter part of the drained slice (if not overwritten by tail).
688         //       These were dropped in step 1 if T:Drop.
689         //    b) Original data of the tail items (which have now been copied).
690         //       These need to be dropped if T:Drop, as ptr::copy doesn't drop the source.
691         // After any necessary drops, this entire region must be zeroized.
692
693         let mut current_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(new_len);
694         let end_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(self.original_vec_len);
695
696         // Determine the start of the original tail's memory region
697         let original_tail_start_ptr_val = tail_src_ptr as usize;
698
699         while current_cleanup_ptr < end_cleanup_ptr {
700            if mem::needs_drop::<T>() {
701               let current_ptr_val = current_cleanup_ptr as usize;
702               let original_tail_end_ptr_val =
703                  original_tail_start_ptr_val + self.tail_len * mem::size_of::<T>();
704
705               if current_ptr_val >= original_tail_start_ptr_val
706                  && current_ptr_val < original_tail_end_ptr_val
707               {
708                  // This element was part of the original tail. ptr::copy moved its value.
709                  // The original instance here needs to be dropped.
710                  ptr::drop_in_place(current_cleanup_ptr);
711               }
712               // Else, it was part of the drained range (not covered by tail move).
713               // If it needed dropping, it was handled in step 1.
714            }
715
716            // Zeroize the memory of this element.
717            (*current_cleanup_ptr).zeroize();
718            current_cleanup_ptr = current_cleanup_ptr.add(1);
719         }
720
721         // Update the SecureVec's length.
722         self.vec_ref.len = new_len;
723
724         // Relock the SecureVec's memory.
725         self.vec_ref.lock_memory();
726      }
727   }
728}
729
/// Turn any `RangeBounds<usize>` into a half-open `(start, end)` pair validated
/// against a collection of length `len`.
///
/// # Panics
///
/// Panics on `Excluded(usize::MAX)` as a start or `Included(usize::MAX)` as an end,
/// if `start > end`, or if `end > len`.
fn resolve_range_indices<R: RangeBounds<usize>>(range: R, len: usize) -> (usize, usize) {
   let start = match range.start_bound() {
      Bound::Unbounded => 0,
      Bound::Included(&first) => first,
      Bound::Excluded(&before) => match before.checked_add(1) {
         Some(first) => first,
         None => panic!("attempted to start drain at Excluded(usize::MAX)"),
      },
   };

   let end = match range.end_bound() {
      Bound::Unbounded => len,
      Bound::Excluded(&past) => past,
      Bound::Included(&last) => match last.checked_add(1) {
         Some(past) => past,
         None => panic!("attempted to end drain at Included(usize::MAX)"),
      },
   };

   assert!(
      start <= end,
      "drain range start ({}) must be less than or equal to end ({})",
      start,
      end
   );
   assert!(
      end <= len,
      "drain range end ({}) out of bounds for slice of length {}",
      end,
      len
   );

   (start, end)
}
766
// Tests for the OS-backed implementation (locking, page protection and, on
// Windows, encryption); they are only built with the `use_os` feature.
#[cfg(all(test, feature = "use_os"))]
mod tests {
   use super::*;
   use std::process::{Command, Stdio};
   use std::sync::{Arc, Mutex};

   // All three constructors copy the data in; `from_slice_mut` must also
   // zeroize its source slice.
   #[test]
   fn test_creation() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure_vec = SecureVec::from_vec(vec).unwrap();

      secure_vec.unlock_slice(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });

      let exposed_slice = &mut [1, 2, 3];
      let secure_slice = SecureVec::from_slice_mut(exposed_slice).unwrap();
      assert_eq!(exposed_slice, &[0u8; 3]);

      secure_slice.unlock_slice(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });

      let exposed_slice = [1, 2, 3];
      let secure_slice = SecureVec::from_slice(&exposed_slice).unwrap();

      secure_slice.unlock_slice(|slice| {
         assert_eq!(slice, exposed_slice);
      });
   }

   // Converting a `SecureArray` keeps its length and contents.
   #[test]
   fn test_from_secure_array() {
      let exposed: &mut [u8; 3] = &mut [1, 2, 3];
      let array: SecureArray<u8, 3> = SecureArray::from_slice_mut(exposed).unwrap();
      let vec: SecureVec<u8> = array.into();
      assert_eq!(vec.len(), 3);
      vec.unlock_slice(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   // Every constructor allocates a non-empty buffer, even when asked for none.
   #[test]
   fn test_size_cannot_be_zero() {
      let secure: SecureVec<u8> = SecureVec::new().unwrap();
      let size = secure.aligned_size();
      assert_eq!(size > 0, true);

      let secure: SecureVec<u8> = SecureVec::from_vec(vec![]).unwrap();
      let size = secure.aligned_size();
      assert_eq!(size > 0, true);

      let secure: SecureVec<u8> = SecureVec::new_with_capacity(0).unwrap();
      let size = secure.aligned_size();
      assert_eq!(size > 0, true);
   }

   // Unlocking and re-locking a fresh vector reports success for both steps.
   #[test]
   fn lock_unlock_works() {
      let secure: SecureVec<u8> = SecureVec::new().unwrap();

      let (decrypted, unlocked) = secure.unlock_memory();
      assert!(decrypted);
      assert!(unlocked);

      let (encrypted, locked) = secure.lock_memory();
      assert!(encrypted);
      assert!(locked);
   }

   // Pushes from several threads through `Arc<Mutex<_>>` are all retained.
   #[test]
   fn test_thread_safety() {
      let vec: Vec<u8> = vec![];
      let secure = SecureVec::from_vec(vec).unwrap();
      let secure = Arc::new(Mutex::new(secure));

      let mut handles = Vec::new();
      for i in 0..5u8 {
         let secure_clone = secure.clone();
         let handle = std::thread::spawn(move || {
            let mut secure = secure_clone.lock().unwrap();
            secure.push(i);
         });
         handles.push(handle);
      }

      for handle in handles {
         handle.join().unwrap();
      }

      let mut sec = secure.lock().unwrap();
      // Thread scheduling order is arbitrary, so sort before comparing.
      sec.unlock_slice_mut(|slice| {
         slice.sort();
         assert_eq!(slice.len(), 5);
         assert_eq!(slice, &[0, 1, 2, 3, 4]);
      });
   }

   // A clone holds the same elements as the original.
   #[test]
   fn test_clone() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure1 = SecureVec::from_vec(vec).unwrap();
      let secure2 = secure1.clone();

      secure1.unlock_slice(|slice| {
         secure2.unlock_slice(|slice2| {
            assert_eq!(slice, slice2);
         });
      });
   }

   // Documents what leaking a `Drain` does: `drain` already set the len to the
   // start of the range (0 here) and the forgotten `Drain` never re-locks.
   #[test]
   fn test_do_not_call_forget_on_drain() {
      let vec: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      let drain = secure.drain(..3);
      core::mem::forget(drain);
      // The vector is still usable, but all of its original elements are now unreachable
      secure.unlock_slice(|secure| {
         assert_eq!(secure.len(), 0);
      });
   }

   // Draining a prefix yields those items and shifts the tail down.
   #[test]
   fn test_drain() {
      let vec: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      let mut drain = secure.drain(..3);
      assert_eq!(drain.next(), Some(1));
      assert_eq!(drain.next(), Some(2));
      assert_eq!(drain.next(), Some(3));
      assert_eq!(drain.next(), None);
      drop(drain);
      secure.unlock_slice(|secure| {
         assert_eq!(secure.len(), 7);
         assert_eq!(secure, &[4, 5, 6, 7, 8, 9, 10]);
      });
   }

   // JSON round-trip preserves the bytes.
   #[cfg(feature = "serde")]
   #[test]
   fn test_secure_vec_serde() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      let json = serde_json::to_vec(&secure).expect("Serialization failed");
      let deserialized: SecureVec<u8> =
         serde_json::from_slice(&json).expect("Deserialization failed");
      deserialized.unlock_slice(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   // `erase` empties the vector but keeps its capacity.
   #[test]
   fn test_erase() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      secure.erase();

      secure.unlock(|secure| {
         assert_eq!(secure.len, 0);
         assert_eq!(secure.capacity, 3);
      });

      // len is 0, so this loop body never runs; kept as a guard.
      secure.unlock_iter(|iter| {
         for elem in iter {
            assert_eq!(elem, &0);
         }
      });
   }

   // Pushing past the initial capacity grows the vector and keeps order.
   #[test]
   fn test_push() {
      let vec: Vec<u8> = Vec::new();
      let mut secure = SecureVec::from_vec(vec).unwrap();
      for i in 0..10 {
         secure.push(i);
      }

      assert_eq!(secure.len(), 10);

      secure.unlock_slice(|slice| {
         assert_eq!(slice, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
      });
   }

   // A request larger than double the capacity allocates exactly what was asked.
   #[test]
   fn test_reserve() {
      let mut secure: SecureVec<u8> = SecureVec::new().unwrap();
      secure.reserve(10);
      assert_eq!(secure.capacity, 10);
   }

   // Filling the capacity and pushing once more doubles it.
   #[test]
   fn test_reserve_doubling() {
      let mut secure: SecureVec<u8> = SecureVec::new().unwrap();
      secure.reserve(10);

      for i in 0..9 {
         secure.push(i);
      }

      secure.push(9);
      assert_eq!(secure.capacity, 10);
      assert_eq!(secure.len(), 10);

      secure.push(10);
      assert_eq!(secure.capacity, 20);
      assert_eq!(secure.len(), 11);
   }

   // Indexing works while the memory is unlocked.
   #[test]
   fn test_index() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      secure.unlock(|secure| {
         assert_eq!(secure[0], 1);
         assert_eq!(secure[1], 2);
         assert_eq!(secure[2], 3);
      });
   }

   #[test]
   fn test_unlock_slice() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      secure.unlock_slice(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   #[test]
   fn test_unlock_slice_mut() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let mut secure = SecureVec::from_vec(vec).unwrap();

      secure.unlock_slice_mut(|slice| {
         slice[0] = 4;
         assert_eq!(slice, &mut [4, 2, 3]);
      });
   }

   // An empty vector (only capacity, no elements) iterates nothing.
   #[test]
   fn test_unlock_iter() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      let sum: u8 = secure.unlock_iter(|iter| iter.map(|&x| x).sum());

      assert_eq!(sum, 6);

      let secure: SecureVec<u8> = SecureVec::new_with_capacity(3).unwrap();
      let sum: u8 = secure.unlock_iter(|iter| iter.map(|&x| x).sum());

      assert_eq!(sum, 0);
   }

   #[test]
   fn test_unlock_iter_mut() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      secure.unlock_iter_mut(|iter| {
         for elem in iter {
            *elem += 1;
         }
      });

      secure.unlock_slice(|slice| {
         assert_eq!(slice, &[2, 3, 4]);
      });
   }

   // Re-runs this test binary as a child process with a marker argument. The
   // child indexes a locked vector, which must kill it with SIGSEGV/SIGBUS
   // (Unix) or STATUS_ACCESS_VIOLATION (Windows).
   #[test]
   fn test_index_should_fail_when_locked() {
      let arg = "CRASH_TEST_SECUREVEC_LOCKED";

      if std::env::args().any(|a| a == arg) {
         let vec: Vec<u8> = vec![1, 2, 3];
         let secure = SecureVec::from_vec(vec).unwrap();
         let _value = core::hint::black_box(secure[0]);

         // Only reached if the read above did not crash.
         std::process::exit(1);
      }

      let child = Command::new(std::env::current_exe().unwrap())
         .arg("vec::tests::test_index_should_fail_when_locked")
         .arg(arg)
         .arg("--nocapture")
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .spawn()
         .expect("Failed to spawn child process");

      let output = child.wait_with_output().expect("Failed to wait on child");
      let status = output.status;

      assert!(
         !status.success(),
         "Process exited successfully with code {:?}, but it should have crashed.",
         status.code()
      );

      #[cfg(unix)]
      {
         use std::os::unix::process::ExitStatusExt;
         let signal = status
            .signal()
            .expect("Process was not terminated by a signal on Unix.");
         assert!(
            signal == libc::SIGSEGV || signal == libc::SIGBUS,
            "Process terminated with unexpected signal: {}",
            signal
         );
         println!(
            "Test passed: Process correctly terminated with signal {}.",
            signal
         );
      }

      #[cfg(windows)]
      {
         const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
         assert_eq!(
            status.code(),
            Some(STATUS_ACCESS_VIOLATION),
            "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
            status.code()
         );
         eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
      }
   }
}