// secure_types/vec.rs

#[cfg(not(feature = "std"))]
use alloc::{Layout, alloc, dealloc};

#[cfg(feature = "std")]
use std::vec::Vec;

use super::Error;
use core::{
   iter::FusedIterator,
   marker::PhantomData,
   mem,
   ops::{Bound, RangeBounds},
   ptr::{self, NonNull},
};
use zeroize::Zeroize;

#[cfg(feature = "std")]
use super::page_size;
#[cfg(feature = "std")]
use memsec::Prot;
21pub type SecureBytes = SecureVec<u8>;
22
23/// A securely allocated, growable vector, analogous to `std::vec::Vec`.
24///
25/// `SecureVec<T>` is designed to hold a sequence of sensitive data elements. It serves as the
26/// foundational secure collection in this crate.
27///
28/// ## Security Model
29///
30/// When compiled with the `std` feature (the default), it provides several layers of protection:
31/// - **Zeroization on Drop**: The memory region is securely zeroized when the vector is dropped.
32/// - **Memory Locking**: The underlying memory pages are locked using `mlock` (Unix) or
33///   `VirtualLock` (Windows) to prevent the OS from swapping them to disk.
34/// - **Memory Encryption**: On Windows, the memory is also encrypted using `CryptProtectMemory`.
35///
36/// In a `no_std` environment, it falls back to providing only the **zeroization-on-drop** guarantee.
37///
38/// ## Program Termination
39///
40/// Direct indexing (e.g., `vec[0]`) on a locked vector will cause the operating system
41/// to terminate the process with an access violation error. This is by design.
42///
43/// Always use the provided scope methods (`slice_scope`, `slice_mut_scope`) for safe access.
44///
45/// # Examples
46///
47/// Using `SecureBytes` (a type alias for `SecureVec<u8>`) to handle a secret key.
48///
49/// ```
50/// use secure_types::SecureBytes;
51///
52/// // Create a new, empty secure vector.
53/// let mut secret_key = SecureBytes::new().unwrap();
54///
55/// // Push some sensitive data into it.
56/// secret_key.push(0xAB);
57/// secret_key.push(0xCD);
58/// secret_key.push(0xEF);
59///
60/// // The memory is locked here.
61///
62/// // Use a scope to safely access the contents as a slice.
63/// secret_key.slice_scope(|unlocked_slice| {
64///     assert_eq!(unlocked_slice, &[0xAB, 0xCD, 0xEF]);
65///     println!("Secret Key: {:?}", unlocked_slice);
66/// });
67///
68/// // The memory is automatically locked again when the scope ends.
69///
70/// // When `secret_key` is dropped, its memory is securely zeroized.
71/// ```
72pub struct SecureVec<T>
73where
74   T: Zeroize,
75{
76   ptr: NonNull<T>,
77   pub(crate) len: usize,
78   pub(crate) capacity: usize,
79   _marker: PhantomData<T>,
80}
81
82unsafe impl<T: Zeroize + Send> Send for SecureVec<T> {}
83unsafe impl<T: Zeroize + Send + Sync> Sync for SecureVec<T> {}
84
85impl<T: Zeroize> SecureVec<T> {
86   pub fn new() -> Result<Self, Error> {
87      // Give at least a capacity of 1 so encryption/decryption can be done.
88      let capacity = 1;
89      let size = capacity * mem::size_of::<T>();
90
91      #[cfg(feature = "std")]
92      let ptr = unsafe {
93         let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
94         let allocated_ptr = memsec::malloc_sized(aligned_size);
95         let ptr = allocated_ptr.ok_or(Error::AllocationFailed)?;
96         ptr.as_ptr() as *mut T
97      };
98
99      #[cfg(not(feature = "std"))]
100      let ptr = {
101         let layout = Layout::from_size_align(size, mem::align_of::<T>())
102            .map_err(|_| Error::AllocationFailed)?;
103         let ptr = unsafe { alloc::alloc(layout) as *mut T };
104         if ptr.is_null() {
105            return Err(Error::AllocationFailed);
106         }
107         ptr
108      };
109
110      let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
111      let secure = SecureVec {
112         ptr: non_null,
113         len: 0,
114         capacity,
115         _marker: PhantomData,
116      };
117
118      let (encrypted, locked) = secure.lock_memory();
119
120      #[cfg(feature = "std")]
121      if !locked {
122         return Err(Error::LockFailed);
123      }
124
125      #[cfg(feature = "std")]
126      if !encrypted {
127         return Err(Error::CryptProtectMemoryFailed);
128      }
129
130      Ok(secure)
131   }
132
133   pub fn with_capacity(mut capacity: usize) -> Result<Self, Error> {
134      if capacity == 0 {
135         capacity = 1;
136      }
137
138      let size = capacity * mem::size_of::<T>();
139
140      #[cfg(feature = "std")]
141      let ptr = unsafe {
142         let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
143         let allocated_ptr = memsec::malloc_sized(aligned_size);
144         let ptr = allocated_ptr.ok_or(Error::AllocationFailed)?;
145         ptr.as_ptr() as *mut T
146      };
147
148      #[cfg(not(feature = "std"))]
149      let ptr = {
150         let layout = Layout::from_size_align(size, mem::align_of::<T>())
151            .map_err(|_| Error::AllocationFailed)?;
152         let ptr = unsafe { alloc::alloc(layout) as *mut T };
153         if ptr.is_null() {
154            return Err(Error::AllocationFailed);
155         }
156         ptr
157      };
158
159      let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
160
161      let secure = SecureVec {
162         ptr: non_null,
163         len: 0,
164         capacity,
165         _marker: PhantomData,
166      };
167
168      let (encrypted, locked) = secure.lock_memory();
169
170      #[cfg(feature = "std")]
171      if !locked {
172         return Err(Error::LockFailed);
173      }
174
175      #[cfg(feature = "std")]
176      if !encrypted {
177         return Err(Error::CryptProtectMemoryFailed);
178      }
179
180      Ok(secure)
181   }
182
183   #[cfg(feature = "std")]
184   pub fn from_vec(mut vec: Vec<T>) -> Result<Self, Error> {
185      if vec.capacity() == 0 {
186         vec.reserve(1);
187      }
188
189      let capacity = vec.capacity();
190      let len = vec.len();
191
192      // Allocate memory
193      let size = capacity * mem::size_of::<T>();
194
195      let ptr = unsafe {
196         let aligned_size = (size + page_size() - 1) & !(page_size() - 1);
197         let allocated_ptr = memsec::malloc_sized(aligned_size);
198         if allocated_ptr.is_none() {
199            vec.zeroize();
200            return Err(Error::AllocationFailed);
201         } else {
202            allocated_ptr.unwrap().as_ptr() as *mut T
203         }
204      };
205
206      // Copy data from the old pointer to the new one
207      unsafe {
208         core::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, len);
209      }
210
211      // Zeroize and drop the original Vec
212      vec.zeroize();
213      drop(vec);
214
215      let non_null = NonNull::new(ptr).ok_or(Error::NullAllocation)?;
216
217      let secure = SecureVec {
218         ptr: non_null,
219         len,
220         capacity,
221         _marker: PhantomData,
222      };
223
224      let (encrypted, locked) = secure.lock_memory();
225
226      if !locked {
227         return Err(Error::LockFailed);
228      }
229
230      if !encrypted {
231         return Err(Error::CryptProtectMemoryFailed);
232      }
233
234      Ok(secure)
235   }
236
237   pub fn len(&self) -> usize {
238      self.len
239   }
240
241   pub fn as_ptr(&self) -> *const T {
242      self.ptr.as_ptr()
243   }
244
245   pub fn as_mut_ptr(&mut self) -> *mut u8 {
246      self.ptr.as_ptr() as *mut u8
247   }
248
249   #[allow(dead_code)]
250   fn allocated_byte_size(&self) -> usize {
251      let size = self.capacity * mem::size_of::<T>();
252      #[cfg(feature = "std")]
253      {
254         (size + page_size() - 1) & !(page_size() - 1)
255      }
256      #[cfg(not(feature = "std"))]
257      {
258         size // No page alignment in no_std
259      }
260   }
261
262   #[cfg(all(feature = "std", windows))]
263   fn encypt_memory(&self) -> bool {
264      let ptr = self.as_ptr() as *mut u8;
265      super::crypt_protect_memory(ptr, self.allocated_byte_size())
266   }
267
268   #[cfg(all(feature = "std", windows))]
269   fn decrypt_memory(&self) -> bool {
270      let ptr = self.as_ptr() as *mut u8;
271      super::crypt_unprotect_memory(ptr, self.allocated_byte_size())
272   }
273
274   /// Lock the memory region
275   ///
276   /// On Windows also calls `CryptProtectMemory` to encrypt the memory
277   ///
278   /// On Unix it just calls `mprotect` to lock the memory
279   pub(crate) fn lock_memory(&self) -> (bool, bool) {
280      #[cfg(feature = "std")]
281      {
282         #[cfg(windows)]
283         {
284            let encrypt_ok = self.encypt_memory();
285            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
286            (encrypt_ok, mprotect_ok)
287         }
288         #[cfg(unix)]
289         {
290            let mprotect_ok = super::mprotect(self.ptr, Prot::NoAccess);
291            (true, mprotect_ok)
292         }
293      }
294      #[cfg(not(feature = "std"))]
295      {
296         (true, true) // No-op: always "succeeds"
297      }
298   }
299
300   /// Unlock the memory region
301   ///
302   /// On Windows also calls `CryptUnprotectMemory` to decrypt the memory
303   ///
304   /// On Unix it just calls `mprotect` to unlock the memory
305   pub(crate) fn unlock_memory(&self) -> (bool, bool) {
306      #[cfg(feature = "std")]
307      {
308         #[cfg(windows)]
309         {
310            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
311            if !mprotect_ok {
312               return (false, false);
313            }
314            let decrypt_ok = self.decrypt_memory();
315            (decrypt_ok, mprotect_ok)
316         }
317         #[cfg(unix)]
318         {
319            let mprotect_ok = super::mprotect(self.ptr, Prot::ReadWrite);
320            (true, mprotect_ok)
321         }
322      }
323
324      #[cfg(not(feature = "std"))]
325      {
326         (true, true) // No-op: always "succeeds"
327      }
328   }
329
330   /// Immutable access to the `SecureVec`
331   pub fn unlock_scope<F, R>(&self, f: F) -> R
332   where
333      F: FnOnce(&SecureVec<T>) -> R,
334   {
335      self.unlock_memory();
336      let result = f(self);
337      self.lock_memory();
338      result
339   }
340
341   /// Immutable access to the `SecureVec` as `&[T]`
342   pub fn slice_scope<F, R>(&self, f: F) -> R
343   where
344      F: FnOnce(&[T]) -> R,
345   {
346      unsafe {
347         self.unlock_memory();
348         let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
349         let result = f(slice);
350         self.lock_memory();
351         result
352      }
353   }
354
355   /// Mutable access to the `SecureVec` as `&mut [T]`
356   pub fn slice_mut_scope<F, R>(&mut self, f: F) -> R
357   where
358      F: FnOnce(&mut [T]) -> R,
359   {
360      unsafe {
361         self.unlock_memory();
362         let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
363         let result = f(slice);
364         self.lock_memory();
365         result
366      }
367   }
368
369   /// Immutable access to the `SecureVec` as `Iter<T>`
370   ///
371   /// ## Use with caution
372   ///
373   /// You can actually return a new allocated `Vec` from this function
374   ///
375   /// If you do that you are responsible for zeroizing its contents
376   pub fn iter_scope<F, R>(&self, f: F) -> R
377   where
378      F: FnOnce(core::slice::Iter<T>) -> R,
379   {
380      unsafe {
381         self.unlock_memory();
382         let slice = core::slice::from_raw_parts(self.ptr.as_ptr(), self.len);
383         let iter = slice.iter();
384         let result = f(iter);
385         self.lock_memory();
386         result
387      }
388   }
389
390   /// Mutable access to the `SecureVec` as `IterMut<T>`
391   ///
392   /// ## Use with caution
393   ///
394   /// You can actually return a new allocated `Vec` from this function
395   ///
396   /// If you do that you are responsible for zeroizing its contents
397   pub fn iter_mut_scope<F, R>(&mut self, f: F) -> R
398   where
399      F: FnOnce(core::slice::IterMut<T>) -> R,
400   {
401      unsafe {
402         self.unlock_memory();
403         let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
404         let iter = slice.iter_mut();
405         let result = f(iter);
406         self.lock_memory();
407         result
408      }
409   }
410
411   /// Erase the underlying data and clears the vector
412   ///
413   /// The memory is locked again and the capacity is preserved for reuse
414   pub fn erase(&mut self) {
415      unsafe {
416         self.unlock_memory();
417         let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
418         for elem in slice.iter_mut() {
419            elem.zeroize();
420         }
421         self.clear();
422         self.lock_memory();
423      }
424   }
425
426   /// Clear the vector
427   ///
428   /// This just sets the vector's len to zero it does not erase the underlying data
429   pub fn clear(&mut self) {
430      self.len = 0;
431   }
432
433   pub fn push(&mut self, value: T) {
434      if self.len >= self.capacity {
435         // Reallocate
436         let new_capacity = if self.capacity == 0 {
437            1
438         } else {
439            self.capacity + 1
440         };
441
442         let items_byte_size = new_capacity * mem::size_of::<T>();
443
444         // Allocate new memory
445         #[cfg(feature = "std")]
446         let new_ptr = unsafe {
447            let aligned_allocation_size = (items_byte_size + page_size() - 1) & !(page_size() - 1);
448            memsec::malloc_sized(aligned_allocation_size)
449               .expect("Failed to allocate memory")
450               .as_ptr() as *mut T
451         };
452
453         #[cfg(not(feature = "std"))]
454         let new_ptr = {
455            let layout = Layout::from_size_align(items_byte_size, mem::align_of::<T>())
456               .expect("Failed to allocate memory");
457            let ptr = unsafe { alloc::alloc(layout) as *mut T };
458            if ptr.is_null() {
459               panic!("Null pointer returned from alloc");
460            }
461            ptr
462         };
463
464         // Copy data to new pointer, erase and free old memory
465         unsafe {
466            self.unlock_memory();
467            core::ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr, self.len);
468            if self.capacity > 0 {
469               let slice = core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len);
470               for elem in slice.iter_mut() {
471                  elem.zeroize();
472               }
473            }
474            #[cfg(feature = "std")]
475            memsec::free(self.ptr);
476
477            #[cfg(not(feature = "std"))]
478            {
479               let old_size = self.capacity * mem::size_of::<T>();
480               let old_layout = Layout::from_size_align_unchecked(old_size, mem::align_of::<T>());
481               dealloc(self.ptr.as_ptr() as *mut u8, old_layout);
482            }
483         }
484
485         // Update pointer and capacity
486         self.ptr = NonNull::new(new_ptr).expect("Failed to create NonNull");
487         self.capacity = new_capacity;
488
489         // write and lock
490         unsafe {
491            core::ptr::write(self.ptr.as_ptr().add(self.len), value);
492            self.len += 1;
493            self.lock_memory();
494         }
495      } else {
496         // Unlock, write, and relock
497         unsafe {
498            self.unlock_memory();
499            core::ptr::write(self.ptr.as_ptr().add(self.len), value);
500            self.len += 1;
501            self.lock_memory();
502         }
503      }
504   }
505
506   /// Creates a draining iterator that removes the specified range from the vector
507   /// and yields the removed items.
508   ///
509   /// Note: The vector is unlocked during the lifetime of the `Drain` iterator.
510   /// The memory is relocked when the `Drain` iterator is dropped.
511   ///
512   /// # Panics
513   /// Panics if the starting point is greater than the end point or if the end point
514   /// is greater than the length of the vector.
515   pub fn drain<R>(&mut self, range: R) -> Drain<'_, T>
516   where
517      R: RangeBounds<usize>,
518   {
519      self.unlock_memory();
520
521      let original_len = self.len;
522
523      let (drain_start_idx, drain_end_idx) = resolve_range_indices(range, original_len);
524
525      let tail_len = original_len - drain_end_idx;
526
527      self.len = drain_start_idx;
528
529      Drain {
530         vec_ref: self,
531         drain_start_index: drain_start_idx,
532         current_drain_iter_index: drain_start_idx,
533         drain_end_index: drain_end_idx,
534         original_vec_len: original_len,
535         tail_len,
536         _marker: PhantomData,
537      }
538   }
539}
540
541impl<T: Clone + Zeroize> Clone for SecureVec<T> {
542   fn clone(&self) -> Self {
543      let mut new_vec: SecureVec<T> = SecureVec::with_capacity(self.capacity).unwrap();
544
545      new_vec.unlock_memory();
546      self.unlock_memory();
547
548      unsafe {
549         for i in 0..self.len {
550            let value = (*self.ptr.as_ptr().add(i)).clone();
551            core::ptr::write(new_vec.ptr.as_ptr().add(i), value);
552         }
553      }
554
555      new_vec.len = self.len;
556      self.lock_memory();
557      new_vec.lock_memory();
558      new_vec
559   }
560}
561
562impl<T: Zeroize> Drop for SecureVec<T> {
563   fn drop(&mut self) {
564      self.erase();
565      self.unlock_memory();
566      unsafe {
567         #[cfg(feature = "std")]
568         memsec::free(self.ptr);
569
570         #[cfg(not(feature = "std"))]
571         {
572            // Recreate the layout to deallocate correctly
573            let layout =
574               Layout::from_size_align_unchecked(self.allocated_byte_size(), mem::align_of::<T>());
575            dealloc(self.ptr.as_ptr() as *mut u8, layout);
576         }
577      }
578   }
579}
580
581impl<T: Zeroize> core::ops::Index<usize> for SecureVec<T> {
582   type Output = T;
583
584   fn index(&self, index: usize) -> &Self::Output {
585      assert!(index < self.len, "Index out of bounds");
586      unsafe {
587         let ptr = self.ptr.as_ptr().add(index);
588         let reference = &*ptr;
589         reference
590      }
591   }
592}
593
#[cfg(feature = "serde")]
impl serde::Serialize for SecureVec<u8> {
   /// Serializes the bytes as a sequence of `u8`, unlocking the memory only for the duration
   /// of the call.
   ///
   /// Note: whatever buffer the serializer writes into is ordinary memory that this type does
   /// not protect or wipe.
   fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
   where
      S: serde::Serializer,
   {
      self.slice_scope(move |bytes| serializer.collect_seq(bytes))
   }
}
604
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SecureVec<u8> {
   /// Deserializes a sequence of bytes into a new `SecureVec<u8>`.
   ///
   /// The bytes are staged in an ordinary heap buffer before being moved into secure memory.
   /// Growing a plain `Vec` with `push` would free every superseded buffer with the secret bytes
   /// still in it. Instead, the staging buffer is grown manually, and each old buffer is
   /// zeroized before it is released. The buffer is also wrapped in `Zeroizing`, so it is wiped
   /// when deserialization fails part-way.
   fn deserialize<D>(deserializer: D) -> Result<SecureVec<u8>, D::Error>
   where
      D: serde::Deserializer<'de>,
   {
      struct SecureVecVisitor;
      impl<'de> serde::de::Visitor<'de> for SecureVecVisitor {
         type Value = SecureVec<u8>;

         fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
            write!(formatter, "a sequence of bytes")
         }

         fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
         where
            A: serde::de::SeqAccess<'de>,
         {
            // The hint comes from the (untrusted) input, so cap the up-front allocation.
            const MAX_PREALLOCATION: usize = 4096;
            let hint = seq.size_hint().unwrap_or(0).min(MAX_PREALLOCATION);
            let mut staging: zeroize::Zeroizing<Vec<u8>> =
               zeroize::Zeroizing::new(Vec::with_capacity(hint));

            while let Some(byte) = seq.next_element::<u8>()? {
               if staging.len() == staging.capacity() {
                  let new_capacity = staging.capacity().saturating_mul(2).max(16);
                  let mut bigger: Vec<u8> = Vec::with_capacity(new_capacity);
                  bigger.extend_from_slice(&staging[..]);
                  // Wipe the old buffer before the assignment below frees it.
                  staging.zeroize();
                  *staging = bigger;
               }
               staging.push(byte);
            }

            // `from_vec` zeroizes the buffer it is given; the `Zeroizing` left behind is empty.
            SecureVec::from_vec(mem::take(&mut *staging)).map_err(serde::de::Error::custom)
         }
      }
      deserializer.deserialize_seq(SecureVecVisitor)
   }
}
634
635/// A draining iterator for `SecureVec<T>`.
636///
637/// This struct is created by the `drain` method on `SecureVec`.
638///
639/// # Safety
640///
641/// The `Drain` iterator relies on being dropped to correctly handle memory
642/// (moving tail elements, zeroizing drained portions, and relocking memory).
643/// If `mem::forget` is called on `Drain`, the `SecureVec` will have its length
644/// zeroed, but the memory for the drained elements and tail might not be
645/// properly zeroized or relocked, potentially leading to data exposure if
646/// `memsec::free` doesn't zeroize.
647pub struct Drain<'a, T: Zeroize + 'a> {
648   vec_ref: &'a mut SecureVec<T>,
649   drain_start_index: usize,
650   current_drain_iter_index: usize,
651   drain_end_index: usize,
652
653   original_vec_len: usize, // Original length of vec_ref before drain
654   tail_len: usize,         // Number of elements after the drain range in the original vec
655
656   _marker: PhantomData<&'a T>,
657}
658
659impl<'a, T: Zeroize> Iterator for Drain<'a, T> {
660   type Item = T;
661
662   fn next(&mut self) -> Option<T> {
663      if self.current_drain_iter_index < self.drain_end_index {
664         // SecureVec is already unlocked by the `drain` method.
665         unsafe {
666            let item_ptr = self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index);
667            let item = ptr::read(item_ptr);
668            self.current_drain_iter_index += 1;
669            Some(item)
670         }
671      } else {
672         None
673      }
674   }
675
676   fn size_hint(&self) -> (usize, Option<usize>) {
677      let remaining = self.drain_end_index - self.current_drain_iter_index;
678      (remaining, Some(remaining))
679   }
680}
681
682impl<'a, T: Zeroize> ExactSizeIterator for Drain<'a, T> {}
683
684impl<'a, T: Zeroize> Drop for Drain<'a, T> {
685   fn drop(&mut self) {
686      unsafe {
687         // The vec_ref's memory is currently unlocked.
688         if mem::needs_drop::<T>() {
689            let mut current_ptr =
690               self.vec_ref.ptr.as_ptr().add(self.current_drain_iter_index) as *mut T;
691            let end_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index) as *mut T;
692            while current_ptr < end_ptr {
693               ptr::drop_in_place(current_ptr);
694               current_ptr = current_ptr.add(1);
695            }
696         }
697
698         let hole_dst_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_start_index) as *mut T;
699         let tail_src_ptr = self.vec_ref.ptr.as_ptr().add(self.drain_end_index) as *mut T;
700
701         if self.tail_len > 0 {
702            ptr::copy(tail_src_ptr, hole_dst_ptr, self.tail_len);
703         }
704
705         // The new length of the vector.
706         let new_len = self.drain_start_index + self.tail_len;
707
708         // Process the memory region that is no longer part of the active vector's content.
709         // This region is from `vec_ref.ptr + new_len` up to `vec_ref.ptr + original_vec_len`.
710         // It contains:
711         //    a) Original data of the latter part of the drained slice (if not overwritten by tail).
712         //       These were dropped in step 1 if T:Drop.
713         //    b) Original data of the tail items (which have now been copied).
714         //       These need to be dropped if T:Drop, as ptr::copy doesn't drop the source.
715         // After any necessary drops, this entire region must be zeroized.
716
717         let mut current_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(new_len) as *mut T;
718         let end_cleanup_ptr = self.vec_ref.ptr.as_ptr().add(self.original_vec_len) as *mut T;
719
720         // Determine the start of the original tail's memory region
721         let original_tail_start_ptr_val = tail_src_ptr as usize;
722
723         while current_cleanup_ptr < end_cleanup_ptr {
724            if mem::needs_drop::<T>() {
725               let current_ptr_val = current_cleanup_ptr as usize;
726               let original_tail_end_ptr_val =
727                  original_tail_start_ptr_val + self.tail_len * mem::size_of::<T>();
728
729               if current_ptr_val >= original_tail_start_ptr_val
730                  && current_ptr_val < original_tail_end_ptr_val
731               {
732                  // This element was part of the original tail. ptr::copy moved its value.
733                  // The original instance here needs to be dropped.
734                  ptr::drop_in_place(current_cleanup_ptr);
735               }
736               // Else, it was part of the drained range (not covered by tail move).
737               // If it needed dropping, it was handled in step 1.
738            }
739
740            // Zeroize the memory of this element.
741            (*current_cleanup_ptr).zeroize();
742            current_cleanup_ptr = current_cleanup_ptr.add(1);
743         }
744
745         // Update the SecureVec's length.
746         self.vec_ref.len = new_len;
747
748         // Relock the SecureVec's memory.
749         self.vec_ref.lock_memory();
750      }
751   }
752}
753
/// Converts any `RangeBounds<usize>` into a half-open `(start, end)` index pair for a
/// collection of length `len`.
///
/// # Panics
///
/// Panics if a bound overflows `usize`, if `start > end`, or if `end > len`.
fn resolve_range_indices<R: RangeBounds<usize>>(range: R, len: usize) -> (usize, usize) {
   let start = match range.start_bound() {
      Bound::Unbounded => 0,
      Bound::Included(&first) => first,
      Bound::Excluded(&before) => match before.checked_add(1) {
         Some(first) => first,
         None => panic!("attempted to start drain at Excluded(usize::MAX)"),
      },
   };

   let end = match range.end_bound() {
      Bound::Unbounded => len,
      Bound::Excluded(&past) => past,
      Bound::Included(&last) => match last.checked_add(1) {
         Some(past) => past,
         None => panic!("attempted to end drain at Included(usize::MAX)"),
      },
   };

   assert!(
      start <= end,
      "drain range start ({}) must be less than or equal to end ({})",
      start,
      end
   );
   assert!(
      end <= len,
      "drain range end ({}) out of bounds for slice of length {}",
      end,
      len
   );

   (start, end)
}
790
#[cfg(all(test, feature = "std"))]
mod tests {
   use super::*;
   use std::process::{Command, Stdio};
   use std::sync::{Arc, Mutex};

   /// Asserts that `secure` owns a non-empty allocation and survives an unlock/lock cycle.
   fn assert_lock_cycle(secure: &SecureVec<u8>) {
      assert!(secure.allocated_byte_size() > 0);

      let (decrypted, unlocked) = secure.unlock_memory();
      assert!(decrypted);
      assert!(unlocked);

      let (encrypted, locked) = secure.lock_memory();
      assert!(encrypted);
      assert!(locked);
   }

   #[test]
   fn test_creation() {
      let vec: Vec<u8> = vec![1, 2, 3];
      // Unwrap every constructor so that an error actually fails the test.
      let _from_vec = SecureVec::from_vec(vec).unwrap();
      let _: SecureVec<u8> = SecureVec::new().unwrap();
      let _: SecureVec<u8> = SecureVec::with_capacity(3).unwrap();
   }

   #[test]
   fn lock_unlock() {
      let secure: SecureVec<u8> = SecureVec::new().unwrap();
      assert_lock_cycle(&secure);

      let secure: SecureVec<u8> = SecureVec::from_vec(vec![]).unwrap();
      assert_lock_cycle(&secure);

      let secure: SecureVec<u8> = SecureVec::with_capacity(0).unwrap();
      assert_lock_cycle(&secure);
   }

   #[test]
   fn test_thread_safety() {
      let vec: Vec<u8> = vec![];
      let secure = SecureVec::from_vec(vec).unwrap();
      let secure = Arc::new(Mutex::new(secure));

      let mut handles = Vec::new();
      for i in 0..5u8 {
         let secure_clone = secure.clone();
         let handle = std::thread::spawn(move || {
            let mut secure = secure_clone.lock().unwrap();
            secure.push(i);
         });
         handles.push(handle);
      }

      for handle in handles {
         handle.join().unwrap();
      }

      let sec = secure.lock().unwrap();
      sec.slice_scope(|slice| {
         assert_eq!(slice.len(), 5);
      });
   }

   #[test]
   fn test_clone() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure1 = SecureVec::from_vec(vec).unwrap();
      let _secure2 = secure1.clone();
   }

   #[test]
   fn test_do_not_call_forget_on_drain() {
      let vec: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      let drain = secure.drain(..3);
      core::mem::forget(drain);
      // we can still use secure vec but its state is unreachable
      secure.slice_scope(|secure| {
         assert_eq!(secure.len(), 0);
      });
   }

   #[test]
   fn test_drain() {
      let vec: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      let mut drain = secure.drain(..3);
      assert_eq!(drain.next(), Some(1));
      assert_eq!(drain.next(), Some(2));
      assert_eq!(drain.next(), Some(3));
      assert_eq!(drain.next(), None);
      drop(drain);
      secure.slice_scope(|secure| {
         assert_eq!(secure.len(), 7);
         assert_eq!(secure, &[4, 5, 6, 7, 8, 9, 10]);
      });
   }

   #[cfg(feature = "serde")]
   #[test]
   fn test_secure_vec_serde() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      let json = serde_json::to_vec(&secure).expect("Serialization failed");
      let deserialized: SecureVec<u8> =
         serde_json::from_slice(&json).expect("Deserialization failed");
      deserialized.slice_scope(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   #[test]
   fn test_erase() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      secure.erase();
      secure.unlock_scope(|secure| {
         assert_eq!(secure.len, 0);
         assert_eq!(secure.capacity, 3);
      });

      secure.push(1);
      secure.push(2);
      secure.push(3);
      secure.unlock_scope(|secure| {
         assert_eq!(secure[0], 1);
         assert_eq!(secure[1], 2);
         assert_eq!(secure[2], 3);
         assert_eq!(secure.len, 3);
         assert_eq!(secure.capacity, 3);
      });
   }

   #[test]
   fn test_push() {
      let vec: Vec<u8> = Vec::new();
      let mut secure = SecureVec::from_vec(vec).unwrap();
      for i in 0..30 {
         secure.push(i);
      }
   }

   #[test]
   fn test_index() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      secure.unlock_scope(|secure| {
         assert_eq!(secure[0], 1);
         assert_eq!(secure[1], 2);
         assert_eq!(secure[2], 3);
      });
   }

   #[test]
   fn test_slice_scoped() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      secure.slice_scope(|slice| {
         assert_eq!(slice, &[1, 2, 3]);
      });
   }

   #[test]
   fn test_slice_mut_scoped() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let mut secure = SecureVec::from_vec(vec).unwrap();

      secure.slice_mut_scope(|slice| {
         slice[0] = 4;
         assert_eq!(slice, &mut [4, 2, 3]);
      });

      secure.slice_scope(|slice| {
         assert_eq!(slice, &[4, 2, 3]);
      });
   }

   #[test]
   fn test_iter_scoped() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let secure = SecureVec::from_vec(vec).unwrap();
      let sum: u8 = secure.iter_scope(|iter| iter.copied().sum());

      assert_eq!(sum, 6);

      let secure: SecureVec<u8> = SecureVec::with_capacity(3).unwrap();
      let sum: u8 = secure.iter_scope(|iter| iter.copied().sum());

      assert_eq!(sum, 0);
   }

   #[test]
   fn test_iter_mut_scoped() {
      let vec: Vec<u8> = vec![1, 2, 3];
      let mut secure = SecureVec::from_vec(vec).unwrap();
      secure.iter_mut_scope(|iter| {
         for elem in iter {
            *elem += 1;
         }
      });

      secure.slice_scope(|slice| {
         assert_eq!(slice, &[2, 3, 4]);
      });
   }

   #[test]
   fn test_index_should_fail_when_locked() {
      let arg = "CRASH_TEST_SECUREVEC_LOCKED";

      if std::env::args().any(|a| a == arg) {
         let vec: Vec<u8> = vec![1, 2, 3];
         let secure = SecureVec::from_vec(vec).unwrap();
         let _value = core::hint::black_box(secure[0]);

         std::process::exit(1);
      }

      let child = Command::new(std::env::current_exe().unwrap())
         .arg("vec::tests::test_index_should_fail_when_locked")
         .arg(arg)
         .arg("--nocapture")
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .spawn()
         .expect("Failed to spawn child process");

      let output = child.wait_with_output().expect("Failed to wait on child");
      let status = output.status;

      assert!(
         !status.success(),
         "Process exited successfully with code {:?}, but it should have crashed.",
         status.code()
      );

      #[cfg(unix)]
      {
         use std::os::unix::process::ExitStatusExt;
         let signal = status
            .signal()
            .expect("Process was not terminated by a signal on Unix.");
         assert!(
            signal == libc::SIGSEGV || signal == libc::SIGBUS,
            "Process terminated with unexpected signal: {}",
            signal
         );
         println!(
            "Test passed: Process correctly terminated with signal {}.",
            signal
         );
      }

      #[cfg(windows)]
      {
         const STATUS_ACCESS_VIOLATION: i32 = 0xC0000005_u32 as i32;
         assert_eq!(
            status.code(),
            Some(STATUS_ACCESS_VIOLATION),
            "Process exited with unexpected code: {:x?}. Expected STATUS_ACCESS_VIOLATION.",
            status.code()
         );
         eprintln!("Test passed: Process correctly terminated with STATUS_ACCESS_VIOLATION.");
      }
   }
}