russh_cryptovec/
cryptovec.rs

1use std::fmt::Debug;
2use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo};
3
4use crate::platform::{self, memset, mlock, munlock};
5
/// A buffer which zeroes its memory on `.clear()`, `.resize()`, and
/// reallocations, to avoid copying secrets around.
///
/// Invariants relied on by the `unsafe` code in this file:
///
/// * `size <= capacity`;
/// * when `capacity == 0`, nothing is allocated and `p` is a dangling
///   placeholder that is only ever used for zero-length accesses;
/// * otherwise `p` points to an allocation of exactly `capacity` bytes with
///   alignment 1, obtained zero-initialized from the global allocator;
/// * the bytes in `size..capacity` are kept zeroed, so growing within the
///   current capacity never exposes stale data.
pub struct CryptoVec {
    /// Start of the buffer. Private: the platform helpers (`mlock`,
    /// `munlock`, `memset`) receive this pointer as an argument instead of
    /// reading the field.
    p: *mut u8,
    /// Number of visible bytes (what `len()` reports and `Deref` exposes).
    size: usize,
    /// Size in bytes of the backing allocation; 0 when nothing is allocated.
    capacity: usize,
}
13
14impl Debug for CryptoVec {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        if self.size == 0 {
17            return f.write_str("<empty>");
18        }
19        write!(f, "<{:?}>", self.size)
20    }
21}
22
23impl Unpin for CryptoVec {}
24unsafe impl Send for CryptoVec {}
25unsafe impl Sync for CryptoVec {}
26
27// Common traits implementations
28impl AsRef<[u8]> for CryptoVec {
29    fn as_ref(&self) -> &[u8] {
30        self.deref()
31    }
32}
33
34impl AsMut<[u8]> for CryptoVec {
35    fn as_mut(&mut self) -> &mut [u8] {
36        self.deref_mut()
37    }
38}
39
40impl Deref for CryptoVec {
41    type Target = [u8];
42    fn deref(&self) -> &[u8] {
43        unsafe { std::slice::from_raw_parts(self.p, self.size) }
44    }
45}
46
47impl DerefMut for CryptoVec {
48    fn deref_mut(&mut self) -> &mut [u8] {
49        unsafe { std::slice::from_raw_parts_mut(self.p, self.size) }
50    }
51}
52
53impl From<String> for CryptoVec {
54    fn from(e: String) -> Self {
55        CryptoVec::from(e.into_bytes())
56    }
57}
58
59impl From<&str> for CryptoVec {
60    fn from(e: &str) -> Self {
61        CryptoVec::from(e.as_bytes())
62    }
63}
64
65impl From<&[u8]> for CryptoVec {
66    fn from(e: &[u8]) -> Self {
67        CryptoVec::from_slice(e)
68    }
69}
70
71impl From<Vec<u8>> for CryptoVec {
72    fn from(e: Vec<u8>) -> Self {
73        let mut c = CryptoVec::new_zeroed(e.len());
74        c.clone_from_slice(&e[..]);
75        c
76    }
77}
78
79// Indexing implementations
80impl Index<RangeFrom<usize>> for CryptoVec {
81    type Output = [u8];
82    fn index(&self, index: RangeFrom<usize>) -> &[u8] {
83        self.deref().index(index)
84    }
85}
86impl Index<RangeTo<usize>> for CryptoVec {
87    type Output = [u8];
88    fn index(&self, index: RangeTo<usize>) -> &[u8] {
89        self.deref().index(index)
90    }
91}
92impl Index<Range<usize>> for CryptoVec {
93    type Output = [u8];
94    fn index(&self, index: Range<usize>) -> &[u8] {
95        self.deref().index(index)
96    }
97}
98impl Index<RangeFull> for CryptoVec {
99    type Output = [u8];
100    fn index(&self, _: RangeFull) -> &[u8] {
101        self.deref()
102    }
103}
104
105impl IndexMut<RangeFull> for CryptoVec {
106    fn index_mut(&mut self, _: RangeFull) -> &mut [u8] {
107        self.deref_mut()
108    }
109}
110impl IndexMut<RangeFrom<usize>> for CryptoVec {
111    fn index_mut(&mut self, index: RangeFrom<usize>) -> &mut [u8] {
112        self.deref_mut().index_mut(index)
113    }
114}
115impl IndexMut<RangeTo<usize>> for CryptoVec {
116    fn index_mut(&mut self, index: RangeTo<usize>) -> &mut [u8] {
117        self.deref_mut().index_mut(index)
118    }
119}
120impl IndexMut<Range<usize>> for CryptoVec {
121    fn index_mut(&mut self, index: Range<usize>) -> &mut [u8] {
122        self.deref_mut().index_mut(index)
123    }
124}
125
126impl Index<usize> for CryptoVec {
127    type Output = u8;
128    fn index(&self, index: usize) -> &u8 {
129        self.deref().index(index)
130    }
131}
132
133// IO-related implementation
134impl std::io::Write for CryptoVec {
135    fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
136        self.extend(buf);
137        Ok(buf.len())
138    }
139
140    fn flush(&mut self) -> Result<(), std::io::Error> {
141        Ok(())
142    }
143}
144
145// Default implementation
146impl Default for CryptoVec {
147    fn default() -> Self {
148        CryptoVec {
149            p: std::ptr::NonNull::dangling().as_ptr(),
150            size: 0,
151            capacity: 0,
152        }
153    }
154}
155
156impl CryptoVec {
157    /// Creates a new `CryptoVec`.
158    pub fn new() -> CryptoVec {
159        CryptoVec::default()
160    }
161
162    /// Creates a new `CryptoVec` with `n` zeros.
163    pub fn new_zeroed(size: usize) -> CryptoVec {
164        unsafe {
165            let capacity = size.next_power_of_two();
166            let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1);
167            let p = std::alloc::alloc_zeroed(layout);
168            let _ = mlock(p, capacity);
169            CryptoVec { p, capacity, size }
170        }
171    }
172
173    /// Creates a new `CryptoVec` with capacity `capacity`.
174    pub fn with_capacity(capacity: usize) -> CryptoVec {
175        unsafe {
176            let capacity = capacity.next_power_of_two();
177            let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1);
178            let p = std::alloc::alloc_zeroed(layout);
179            let _ = mlock(p, capacity);
180            CryptoVec {
181                p,
182                capacity,
183                size: 0,
184            }
185        }
186    }
187
188    /// Length of this `CryptoVec`.
189    ///
190    /// ```
191    /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0)
192    /// ```
193    pub fn len(&self) -> usize {
194        self.size
195    }
196
197    /// Returns `true` if and only if this CryptoVec is empty.
198    ///
199    /// ```
200    /// assert!(russh_cryptovec::CryptoVec::new().is_empty())
201    /// ```
202    pub fn is_empty(&self) -> bool {
203        self.len() == 0
204    }
205
206    /// Resize this CryptoVec, appending zeros at the end. This may
207    /// perform at most one reallocation, overwriting the previous
208    /// version with zeros.
209    pub fn resize(&mut self, size: usize) {
210        if size <= self.capacity && size > self.size {
211            // If this is an expansion, just resize.
212            self.size = size
213        } else if size <= self.size {
214            // If this is a truncation, resize and erase the extra memory.
215            unsafe {
216                memset(self.p.add(size), 0, self.size - size);
217            }
218            self.size = size;
219        } else {
220            // realloc ! and erase the previous memory.
221            unsafe {
222                let next_capacity = size.next_power_of_two();
223                let old_ptr = self.p;
224                let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1);
225                self.p = std::alloc::alloc_zeroed(next_layout);
226                let _ = mlock(self.p, next_capacity);
227
228                if self.capacity > 0 {
229                    std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size);
230                    for i in 0..self.size {
231                        std::ptr::write_volatile(old_ptr.add(i), 0)
232                    }
233                    let _ = munlock(old_ptr, self.capacity);
234                    let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1);
235                    std::alloc::dealloc(old_ptr, layout);
236                }
237
238                if self.p.is_null() {
239                    #[allow(clippy::panic)]
240                    {
241                        panic!("Realloc failed, pointer = {:?} {:?}", self, size)
242                    }
243                } else {
244                    self.capacity = next_capacity;
245                    self.size = size;
246                }
247            }
248        }
249    }
250
251    /// Clear this CryptoVec (retaining the memory).
252    ///
253    /// ```
254    /// let mut v = russh_cryptovec::CryptoVec::new();
255    /// v.extend(b"blabla");
256    /// v.clear();
257    /// assert!(v.is_empty())
258    /// ```
259    pub fn clear(&mut self) {
260        self.resize(0);
261    }
262
263    /// Append a new byte at the end of this CryptoVec.
264    pub fn push(&mut self, s: u8) {
265        let size = self.size;
266        self.resize(size + 1);
267        unsafe { *self.p.add(size) = s }
268    }
269
270    /// Read `n_bytes` from `r`, and append them at the end of this
271    /// `CryptoVec`. Returns the number of bytes read (and appended).
272    pub fn read<R: std::io::Read>(
273        &mut self,
274        n_bytes: usize,
275        mut r: R,
276    ) -> Result<usize, std::io::Error> {
277        let cur_size = self.size;
278        self.resize(cur_size + n_bytes);
279        let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) };
280        // Resize the buffer to its appropriate size.
281        match r.read(s) {
282            Ok(n) => {
283                self.resize(cur_size + n);
284                Ok(n)
285            }
286            Err(e) => {
287                self.resize(cur_size);
288                Err(e)
289            }
290        }
291    }
292
293    /// Write all this CryptoVec to the provided `Write`. Returns the
294    /// number of bytes actually written.
295    ///
296    /// ```
297    /// let mut v = russh_cryptovec::CryptoVec::new();
298    /// v.extend(b"blabla");
299    /// let mut s = std::io::stdout();
300    /// v.write_all_from(0, &mut s).unwrap();
301    /// ```
302    pub fn write_all_from<W: std::io::Write>(
303        &self,
304        offset: usize,
305        mut w: W,
306    ) -> Result<usize, std::io::Error> {
307        assert!(offset < self.size);
308        // if we're past this point, self.p cannot be null.
309        unsafe {
310            let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset);
311            w.write(s)
312        }
313    }
314
315    /// Resize this CryptoVec, returning a mutable borrow to the extra bytes.
316    ///
317    /// ```
318    /// let mut v = russh_cryptovec::CryptoVec::new();
319    /// v.resize_mut(4).clone_from_slice(b"test");
320    /// ```
321    pub fn resize_mut(&mut self, n: usize) -> &mut [u8] {
322        let size = self.size;
323        self.resize(size + n);
324        unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) }
325    }
326
327    /// Append a slice at the end of this CryptoVec.
328    ///
329    /// ```
330    /// let mut v = russh_cryptovec::CryptoVec::new();
331    /// v.extend(b"test");
332    /// ```
333    pub fn extend(&mut self, s: &[u8]) {
334        let size = self.size;
335        self.resize(size + s.len());
336        unsafe {
337            std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len());
338        }
339    }
340
341    /// Create a `CryptoVec` from a slice
342    ///
343    /// ```
344    /// russh_cryptovec::CryptoVec::from_slice(b"test");
345    /// ```
346    pub fn from_slice(s: &[u8]) -> CryptoVec {
347        let mut v = CryptoVec::new();
348        v.resize(s.len());
349        unsafe {
350            std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len());
351        }
352        v
353    }
354}
355
356impl Clone for CryptoVec {
357    fn clone(&self) -> Self {
358        let mut v = Self::new();
359        v.extend(self);
360        v
361    }
362}
363
364// Drop implementation
365impl Drop for CryptoVec {
366    fn drop(&mut self) {
367        if self.capacity > 0 {
368            unsafe {
369                for i in 0..self.size {
370                    std::ptr::write_volatile(self.p.add(i), 0);
371                }
372                let _ = platform::munlock(self.p, self.capacity);
373                let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1);
374                std::alloc::dealloc(self.p, layout);
375            }
376        }
377    }
378}
379
#[cfg(test)]
mod test {
    use super::CryptoVec;

    #[test]
    fn test_new() {
        let crypto_vec = CryptoVec::new();
        assert_eq!(crypto_vec.size, 0);
        assert_eq!(crypto_vec.capacity, 0);
    }

    #[test]
    fn test_resize_expand() {
        let mut crypto_vec = CryptoVec::new_zeroed(5);
        crypto_vec.resize(10);
        assert_eq!(crypto_vec.size, 10);
        assert!(crypto_vec.capacity >= 10);
        assert!(crypto_vec.iter().skip(5).all(|&x| x == 0)); // Ensure newly added elements are zeroed
    }

    #[test]
    fn test_resize_shrink() {
        let mut crypto_vec = CryptoVec::from_slice(b"0123456789");
        crypto_vec.resize(5);
        assert_eq!(crypto_vec.size, 5);
        assert_eq!(crypto_vec.len(), 5);
        // Shrinking keeps the leading elements intact.
        assert_eq!(crypto_vec.as_ref(), b"01234");
    }

    #[test]
    fn test_shrink_then_grow_exposes_only_zeros() {
        let mut crypto_vec = CryptoVec::from_slice(b"secret");
        crypto_vec.resize(2);
        // Growing back within the same capacity must not bring back the
        // truncated bytes.
        crypto_vec.resize(6);
        assert_eq!(crypto_vec.as_ref(), b"se\0\0\0\0");
    }

    #[test]
    fn test_realloc_preserves_contents() {
        let mut crypto_vec = CryptoVec::with_capacity(2);
        crypto_vec.extend(b"ab");
        crypto_vec.extend(b"cdefgh"); // exceeds the capacity, forcing a reallocation
        assert_eq!(crypto_vec.as_ref(), b"abcdefgh");
        assert!(crypto_vec.capacity >= 8);
    }

    #[test]
    fn test_push() {
        let mut crypto_vec = CryptoVec::new();
        crypto_vec.push(1);
        crypto_vec.push(2);
        assert_eq!(crypto_vec.size, 2);
        assert_eq!(crypto_vec[0], 1);
        assert_eq!(crypto_vec[1], 2);
    }

    #[test]
    fn test_read() {
        let mut crypto_vec = CryptoVec::from_slice(b"ab");
        let n = crypto_vec.read(3, &b"cdefg"[..]).unwrap();
        assert_eq!(n, 3);
        assert_eq!(crypto_vec.as_ref(), b"abcde");
    }

    #[test]
    fn test_write_trait() {
        use std::io::Write;

        let mut crypto_vec = CryptoVec::new();
        let bytes_written = crypto_vec.write(&[1, 2, 3]).unwrap();
        assert_eq!(bytes_written, 3);
        assert_eq!(crypto_vec.size, 3);
        assert_eq!(crypto_vec.as_ref(), &[1, 2, 3]);
    }

    #[test]
    fn test_as_ref_as_mut() {
        let mut crypto_vec = CryptoVec::new_zeroed(5);
        let slice_ref: &[u8] = crypto_vec.as_ref();
        assert_eq!(slice_ref.len(), 5);
        let slice_mut: &mut [u8] = crypto_vec.as_mut();
        slice_mut[0] = 1;
        assert_eq!(crypto_vec[0], 1);
    }

    #[test]
    fn test_clone_is_independent() {
        let original = CryptoVec::from_slice(b"abc");
        let mut copy = original.clone();
        copy.push(b'd');
        assert_eq!(original.as_ref(), b"abc");
        assert_eq!(copy.as_ref(), b"abcd");
        assert_ne!(original.p, copy.p);
    }

    #[test]
    fn test_from_string() {
        let input = String::from("hello");
        let crypto_vec: CryptoVec = input.into();
        assert_eq!(crypto_vec.as_ref(), b"hello");
    }

    #[test]
    fn test_from_str() {
        let input = "hello";
        let crypto_vec: CryptoVec = input.into();
        assert_eq!(crypto_vec.as_ref(), b"hello");
    }

    #[test]
    fn test_from_byte_slice() {
        let input = b"hello".as_slice();
        let crypto_vec: CryptoVec = input.into();
        assert_eq!(crypto_vec.as_ref(), b"hello");
    }

    #[test]
    fn test_from_vec() {
        let input = vec![1, 2, 3, 4];
        let crypto_vec: CryptoVec = input.into();
        assert_eq!(crypto_vec.as_ref(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_index() {
        let crypto_vec = CryptoVec::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(crypto_vec[0], 1);
        assert_eq!(crypto_vec[4], 5);
        assert_eq!(&crypto_vec[1..3], &[2, 3]);
    }

    #[test]
    fn test_drop() {
        // Every byte is non-zero, so the wipe in `drop` has real work to do.
        let crypto_vec = CryptoVec::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        drop(crypto_vec);

        // The wipe happens on memory that is freed right afterwards, so safe
        // code cannot observe it. Run under Miri or Valgrind to check that
        // `drop` stays in bounds and frees with the right layout.
    }

    #[test]
    fn test_new_zeroed() {
        let crypto_vec = CryptoVec::new_zeroed(10);
        assert_eq!(crypto_vec.size, 10);
        assert!(crypto_vec.capacity >= 10);
        assert!(crypto_vec.iter().all(|&x| x == 0)); // Ensure all bytes are zeroed
    }

    #[test]
    fn test_clear() {
        let mut crypto_vec = CryptoVec::new();
        crypto_vec.extend(b"blabla");
        crypto_vec.clear();
        assert!(crypto_vec.is_empty());
    }

    #[test]
    fn test_extend() {
        let mut crypto_vec = CryptoVec::new();
        crypto_vec.extend(b"test");
        assert_eq!(crypto_vec.as_ref(), b"test");
    }

    #[test]
    fn test_write_all_from() {
        let mut crypto_vec = CryptoVec::new();
        crypto_vec.extend(b"blabla");

        let mut output: Vec<u8> = Vec::new();
        let written_size = crypto_vec.write_all_from(0, &mut output).unwrap();
        assert_eq!(written_size, 6); // "blabla" has 6 bytes
        assert_eq!(output, b"blabla");

        let mut tail: Vec<u8> = Vec::new();
        let written_tail = crypto_vec.write_all_from(3, &mut tail).unwrap();
        assert_eq!(written_tail, 3);
        assert_eq!(tail, b"bla");
    }

    #[test]
    fn test_resize_mut() {
        let mut crypto_vec = CryptoVec::new();
        crypto_vec.resize_mut(4).clone_from_slice(b"test");
        assert_eq!(crypto_vec.as_ref(), b"test");
    }

    // DocTests cannot be run on with wasm_bindgen_test
    #[cfg(target_arch = "wasm32")]
    mod wasm32 {
        use wasm_bindgen_test::wasm_bindgen_test;

        use super::*;

        wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

        #[wasm_bindgen_test]
        fn test_push_u32_be() {
            let mut crypto_vec = CryptoVec::new();
            let value = 43554u32;
            crypto_vec.push_u32_be(value);
            assert_eq!(crypto_vec.len(), 4); // u32 is 4 bytes long
            assert_eq!(crypto_vec.read_u32_be(0), value);
        }

        #[wasm_bindgen_test]
        fn test_read_u32_be() {
            let mut crypto_vec = CryptoVec::new();
            let value = 99485710u32;
            crypto_vec.push_u32_be(value);
            assert_eq!(crypto_vec.read_u32_be(0), value);
        }
    }
}