// runtime/boxed.rs

// Copyright 2020-2021 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use crate::types::*;
use zeroize::Zeroize;

use core::{
    cell::Cell,
    fmt::{self, Debug},
    mem,
    ptr::NonNull,
    slice,
};

use libsodium_sys::{
    sodium_allocarray, sodium_free, sodium_init, sodium_mlock, sodium_mprotect_noaccess, sodium_mprotect_readonly,
    sodium_mprotect_readwrite,
};
19
/// Protection level currently applied to a guarded allocation (see `mprotect`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Prot {
    /// Neither reads nor writes are permitted; the memory is fully locked.
    NoAccess,
    /// Reads are permitted, writes are forbidden.
    ReadOnly,
    /// Both reads and writes are permitted.
    ReadWrite,
}
26
// Counter for outstanding borrows; `u8` caps concurrent borrows at 255
// (overflow is caught in `Boxed::retain`).
type RefCount = u8;
28
/// A protected piece of memory.
///
/// Access is mediated through `unlock`/`unlock_mut`/`lock`, which maintain a
/// borrow count and adjust the page protection accordingly.
#[derive(Eq)]
pub(crate) struct Boxed<T: Bytes> {
    // The pointer to the underlying protected memory (allocated by libsodium).
    ptr: NonNull<T>,
    // The number of elements of type `T` that can be stored in the pointer.
    len: usize,
    // The current protection level of the data.
    prot: Cell<Prot>,
    // The number of current borrows of this pointer.
    refs: Cell<RefCount>,
}
41
impl<T: Bytes> Boxed<T> {
    /// Allocates guarded memory for `len` elements, runs `init` while the
    /// region is writable, then locks it down to `Prot::NoAccess`.
    ///
    /// Panics if libsodium fails to initialize or the allocation fails.
    pub(crate) fn new<F>(len: usize, init: F) -> Self
    where
        F: FnOnce(&mut Self),
    {
        let mut boxed = Self::new_unlocked(len);
        // Best-effort pin of the memory so it is not swapped to disk.
        // NOTE(review): `len` here is an element count, but `sodium_mlock`
        // takes a byte length — for `T` larger than one byte this under-locks;
        // presumably harmless since libsodium locks its own allocations, but
        // confirm.
        unsafe { lock_memory(boxed.ptr.as_mut(), len) };

        assert!(
            boxed.ptr != core::ptr::NonNull::dangling(),
            "Make sure pointer isn't dangling"
        );
        assert!(boxed.len == len);

        init(&mut boxed);

        // Release the initial ReadWrite access acquired in `new_unlocked`.
        boxed.lock();

        boxed
    }

    /// Fallible variant of [`Self::new`]: if `init` returns `Err`, the error
    /// is propagated and the (already locked) allocation is dropped.
    ///
    /// NOTE(review): unlike `new`, this does not call `lock_memory` on the
    /// fresh allocation — presumably redundant because libsodium already
    /// locks its allocations; confirm the asymmetry is intentional.
    #[allow(dead_code)]
    pub(crate) fn try_new<R, E, F>(len: usize, init: F) -> Result<Self, E>
    where
        F: FnOnce(&mut Self) -> Result<R, E>,
    {
        let mut boxed = Self::new_unlocked(len);

        assert!(
            boxed.ptr != core::ptr::NonNull::dangling(),
            "Make sure pointer isn't dangling"
        );
        assert!(boxed.len == len);

        let res = init(&mut boxed);

        // Always re-lock, even if `init` failed.
        boxed.lock();

        res.map(|_| boxed)
    }

    /// Number of elements of type `T` in the allocation.
    pub(crate) fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the allocation holds zero elements.
    pub(crate) fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total size of the allocation in bytes (`len * T::size()`).
    pub(crate) fn size(&self) -> usize {
        self.len * T::size()
    }

    /// Acquires a read-only borrow, making the memory readable.
    /// Each call must be balanced by a call to [`Self::lock`].
    pub(crate) fn unlock(&self) -> &Self {
        self.retain(Prot::ReadOnly);
        self
    }

    /// Acquires the single writable borrow, making the memory writable.
    /// Must be balanced by a call to [`Self::lock`].
    pub(crate) fn unlock_mut(&mut self) -> &mut Self {
        self.retain(Prot::ReadWrite);
        self
    }

    /// Releases one previously acquired borrow; when the last borrow is
    /// released the memory is re-protected as `NoAccess`.
    pub(crate) fn lock(&self) {
        self.release()
    }

    /// Borrows the first element.
    ///
    /// Panics on a zero-length allocation or while the memory is locked.
    #[allow(dead_code)]
    pub(crate) fn as_ref(&self) -> &T {
        assert!(!self.is_empty(), "Attempted to dereference a zero-length pointer");

        assert!(self.prot.get() != Prot::NoAccess, "May not call Boxed while locked");

        unsafe { self.ptr.as_ref() }
    }

    /// Mutably borrows the first element.
    ///
    /// Panics on a zero-length allocation or unless mutably unlocked.
    pub(crate) fn as_mut(&mut self) -> &mut T {
        assert!(!self.is_empty(), "Attempted to dereference a zero-length pointer");

        assert!(
            self.prot.get() == Prot::ReadWrite,
            "May not call Boxed unless mutably unlocked"
        );

        unsafe { self.ptr.as_mut() }
    }

    /// Views the whole allocation as a slice. Panics while locked.
    pub(crate) fn as_slice(&self) -> &[T] {
        assert!(self.prot.get() != Prot::NoAccess, "May not call Boxed while locked");

        unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// Views the whole allocation as a mutable slice.
    /// Panics unless mutably unlocked.
    pub(crate) fn as_mut_slice(&mut self) -> &mut [T] {
        assert!(
            self.prot.get() == Prot::ReadWrite,
            "May not call Boxed unless mutably unlocked"
        );

        unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Allocates guarded memory via libsodium. The returned value starts in
    /// the `ReadWrite` state with one implicit borrow (`refs == 1`), which the
    /// `new`/`try_new` constructors release via `lock()`.
    fn new_unlocked(len: usize) -> Self {
        if unsafe { sodium_init() == -1 } {
            panic!("Failed to initialize libsodium")
        }

        // NOTE(review): fresh sodium allocations appear to contain a non-zero
        // fill pattern (see `test_init_with_garbage`) — callers must
        // initialize the contents.
        let ptr = NonNull::new(unsafe { sodium_allocarray(len, mem::size_of::<T>()) as *mut _ })
            .expect("Failed to allocate memory");

        Self {
            ptr,
            len,
            prot: Cell::new(Prot::ReadWrite),
            refs: Cell::new(1),
        }
    }

    /// Registers one more borrow at protection level `prot`, enforcing the
    /// borrow state machine: at most one writer, and readers never coexist
    /// with a writer.
    fn retain(&self, prot: Prot) {
        let refs = self.refs.get();

        if refs == 0 {
            // First borrow: transition out of NoAccess to the requested level.
            assert!(prot != Prot::NoAccess, "Must retain readably or writably");

            self.prot.set(prot);
            mprotect(self.ptr.as_ptr(), prot);
        } else {
            // Subsequent borrows are only valid as additional readers.
            assert!(
                Prot::NoAccess != self.prot.get(),
                "Out-of-order retain/release detected"
            );
            assert!(
                Prot::ReadWrite != self.prot.get(),
                "Cannot unlock mutably more than once"
            );
            assert!(Prot::ReadOnly == prot, "Cannot unlock mutably while unlocked immutably");
        }

        // Saturating the u8 counter is a hard error, not silent wrap-around.
        match refs.checked_add(1) {
            Some(v) => self.refs.set(v),
            None if self.is_locked() => panic!("Out-of-order retain/release detected"),
            None => panic!("Retained too many times"),
        };
    }

    /// Drops one borrow; when the count reaches zero the memory is
    /// re-protected as `NoAccess`.
    fn release(&self) {
        assert!(self.refs.get() != 0, "Releases exceeded retains");

        assert!(
            self.prot.get() != Prot::NoAccess,
            "Releasing memory that's already locked"
        );

        // Cannot underflow: refs != 0 was asserted above.
        let refs = self.refs.get().wrapping_sub(1);

        self.refs.set(refs);

        if refs == 0 {
            mprotect(self.ptr.as_ptr(), Prot::NoAccess);
            self.prot.set(Prot::NoAccess);
        }
    }

    /// Returns `true` while the memory is protected as `NoAccess`.
    fn is_locked(&self) -> bool {
        self.prot.get() == Prot::NoAccess
    }

    #[cfg(test)]
    #[allow(dead_code)]
    /// Returns the address of the pointer to the data
    pub fn get_ptr_address(&self) -> usize {
        self.ptr.as_ptr() as *const _ as usize
    }
}
216
217impl<T: Bytes + Randomized> Boxed<T> {
218    #[allow(dead_code)]
219    pub(crate) fn random(len: usize) -> Self {
220        Self::new(len, |b| b.as_mut_slice().randomize())
221    }
222}
223
224impl<T: Bytes + Zeroed> Boxed<T> {
225    #[allow(dead_code)]
226    pub(crate) fn zero(len: usize) -> Self {
227        Self::new(len, |b| b.as_mut_slice().zero())
228    }
229}
230
// This may create undefined behaviour if not used correctly
// Zeroes out the memory and configuration
impl<T: Bytes> Zeroize for Boxed<T> {
    /// Overwrites the protected contents with zeroes, then forcibly resets
    /// the bookkeeping to a locked, empty state (refs 0, NoAccess, len 0).
    ///
    /// NOTE(review): forcing `refs` to 0 while borrows are outstanding breaks
    /// the retain/release pairing — hence the UB warning above.
    fn zeroize(&mut self) {
        self.unlock_mut();
        self.as_mut_slice().zero();
        self.lock();
        // Forcible reset: ignore any borrows callers may still hold.
        self.refs.set(0);
        self.prot.set(Prot::NoAccess);
        self.len = 0;
    }
}
243
impl<T: Bytes> Drop for Boxed<T> {
    /// Frees the underlying libsodium allocation, first asserting that all
    /// borrows were released and the memory was re-locked.
    fn drop(&mut self) {
        extern crate std;

        use std::thread;

        // Skip the invariant checks while unwinding so a panic elsewhere
        // doesn't cascade into a double panic (which would abort).
        if !thread::panicking() {
            assert!(self.refs.get() == 0, "Retains exceeded releases");

            assert!(self.prot.get() == Prot::NoAccess, "Dropped secret was still accessible");
        }

        unsafe { free(self.ptr.as_mut()) }
    }
}
259
260impl<T: Bytes> Debug for Boxed<T> {
261    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
262        write!(fmt, "{{ size: {}, hidden }}", self.size())
263    }
264}
265
266impl<T: Bytes> Clone for Boxed<T> {
267    fn clone(&self) -> Self {
268        Self::new(self.len, |b| {
269            b.as_mut_slice().copy_from_slice(self.unlock().as_slice());
270            self.lock();
271        })
272    }
273}
274
275impl<T: Bytes + ConstEq> PartialEq for Boxed<T> {
276    fn eq(&self, other: &Self) -> bool {
277        if self.len != other.len {
278            return false;
279        }
280
281        let lhs = self.unlock().as_slice();
282        let rhs = other.unlock().as_slice();
283
284        let ret = lhs.const_eq(rhs);
285
286        self.lock();
287        other.lock();
288
289        ret
290    }
291}
292
impl<T: Bytes + Zeroed> From<&mut T> for Boxed<T> {
    /// Moves a single value into guarded memory, zeroing the source so the
    /// secret does not linger in unprotected memory (see `test_init_with_values`).
    fn from(data: &mut T) -> Self {
        Self::new(1, |b| unsafe { data.copy_and_zero(b.as_mut()) })
    }
}
298
impl<T: Bytes + Zeroed> From<&mut [T]> for Boxed<T> {
    /// Moves a slice of values into guarded memory, zeroing the source so the
    /// secrets do not linger in unprotected memory.
    fn from(data: &mut [T]) -> Self {
        Self::new(data.len(), |b| unsafe { data.copy_and_zero(b.as_mut_slice()) })
    }
}
304
// SAFETY: NOTE(review): these impls assert thread-safety manually while
// `refs` and `prot` are non-atomic `Cell`s — concurrent `&Boxed` access from
// multiple threads would race on the borrow counter. Presumably callers
// serialize access externally; confirm before relying on `Sync`.
unsafe impl<T: Bytes + Send> Send for Boxed<T> {}
unsafe impl<T: Bytes + Sync> Sync for Boxed<T> {}
307
308fn mprotect<T>(ptr: *mut T, prot: Prot) {
309    if !match prot {
310        Prot::NoAccess => unsafe { sodium_mprotect_noaccess(ptr as *mut _) == 0 },
311        Prot::ReadOnly => unsafe { sodium_mprotect_readonly(ptr as *mut _) == 0 },
312        Prot::ReadWrite => unsafe { sodium_mprotect_readwrite(ptr as *mut _) == 0 },
313    } {
314        panic!("Error setting memory protection to {:?}", prot);
315    }
316}
317
/// Deallocates guarded memory previously obtained from libsodium.
///
/// # Safety
///
/// `ptr` must have been allocated by libsodium (e.g. `sodium_allocarray`)
/// and must not be used after this call.
pub(crate) unsafe fn free<T>(ptr: *mut T) {
    sodium_free(ptr as *mut _)
}
321
/// Best-effort attempt to pin memory at `ptr` so it is not swapped to disk.
///
/// NOTE(review): the return value of `sodium_mlock` is ignored, so failures
/// (e.g. hitting RLIMIT_MEMLOCK) pass silently; also, the caller in `new`
/// passes an element count while `sodium_mlock` expects bytes — presumably
/// both are tolerable because libsodium locks its own allocations; confirm.
///
/// # Safety
///
/// `ptr` must be valid for `len` bytes.
pub(crate) unsafe fn lock_memory<T>(ptr: *mut T, len: usize) {
    sodium_mlock(ptr as *mut _, len);
}
325
#[cfg(test)]
mod test {
    extern crate alloc;

    use alloc::vec;

    use super::*;
    use libsodium_sys::randombytes_buf;

    // Zeroizing must wipe the underlying bytes, not just the bookkeeping.
    #[test]
    fn boxed_zeroize() {
        let mut boxed = Boxed::<u8>::random(4);
        // Keep a raw view of the allocation so it can be inspected after zeroize.
        let ptr = unsafe { core::slice::from_raw_parts(boxed.ptr.as_ptr(), 4) };
        boxed.unlock();
        assert_ne!(ptr, [0u8; 4]);
        boxed.lock();

        boxed.zeroize();

        boxed.unlock();
        assert_eq!(ptr, [0u8; 4]);
        boxed.lock();
    }

    // A no-op initializer leaves libsodium's (non-zero) fill pattern in place.
    #[test]
    fn test_init_with_garbage() {
        let boxed = Boxed::<u8>::new(4, |_| {});
        let unboxed = boxed.unlock().as_slice();

        // Sample the fill byte from a fresh throwaway allocation.
        let garbage = unsafe {
            let garb_ptr = sodium_allocarray(1, mem::size_of::<u8>()) as *mut u8;
            let garb_byte = *garb_ptr;

            free(garb_ptr);

            vec![garb_byte; unboxed.len()]
        };

        assert_ne!(garbage, vec![0; garbage.len()]);
        assert_eq!(unboxed, &garbage[..]);

        boxed.lock();
    }

    // The init closure can write arbitrary contents before lock-down.
    #[test]
    fn test_custom_init() {
        let boxed = Boxed::<u8>::new(1, |secret| {
            secret.as_mut_slice().copy_from_slice(b"\x04");
        });

        assert_eq!(boxed.unlock().as_slice(), [0x04]);
        boxed.lock();
    }

    // Boxed::zero yields all-zero contents.
    #[test]
    fn test_init_with_zero() {
        let boxed = Boxed::<u8>::zero(6);

        assert_eq!(boxed.unlock().as_slice(), [0, 0, 0, 0, 0, 0]);

        boxed.lock();
    }

    // From<&mut [T]> moves the values in and zeroes the source.
    #[test]
    fn test_init_with_values() {
        let mut value = [8u64];
        let boxed = Boxed::from(&mut value[..]);

        assert_eq!(value, [0]);
        assert_eq!(boxed.unlock().as_slice(), [8]);

        boxed.lock();
    }

    // Equality is symmetric, and differing lengths/contents compare unequal.
    #[allow(clippy::redundant_clone)]
    #[test]
    fn test_eq() {
        let boxed_a = Boxed::<u8>::random(1);
        let boxed_b = boxed_a.clone();

        assert_eq!(boxed_a, boxed_b);
        assert_eq!(boxed_b, boxed_a);

        let boxed_a = Boxed::<u8>::random(16);
        let boxed_b = Boxed::<u8>::random(16);

        assert_ne!(boxed_a, boxed_b);
        assert_ne!(boxed_b, boxed_a);

        let boxed_b = Boxed::<u8>::random(12);

        assert_ne!(boxed_a, boxed_b);
        assert_ne!(boxed_b, boxed_a);
    }

    // The borrow counter tracks unlock/lock pairs exactly.
    #[test]
    fn test_refs() {
        let mut boxed = Boxed::<u8>::zero(8);

        assert_eq!(0, boxed.refs.get());

        let _ = boxed.unlock();
        let _ = boxed.unlock();

        assert_eq!(2, boxed.refs.get());

        boxed.lock();
        boxed.lock();

        assert_eq!(0, boxed.refs.get());

        let _ = boxed.unlock_mut();

        assert_eq!(1, boxed.refs.get());

        boxed.lock();

        assert_eq!(0, boxed.refs.get());
    }

    // Exactly u8::MAX balanced borrows is fine (no overflow panic).
    #[test]
    fn test_ref_overflow() {
        let boxed = Boxed::<u8>::zero(8);

        for _ in 0..u8::max_value() {
            let _ = boxed.unlock();
        }

        for _ in 0..u8::max_value() {
            boxed.lock()
        }
    }

    // Any random (u8-sized) number of balanced borrows is fine.
    #[test]
    fn test_random_borrow_amounts() {
        let boxed = Boxed::<u8>::zero(1);
        let mut counter = 0u8;

        unsafe {
            randombytes_buf(
                counter.as_mut_bytes().as_mut_ptr() as *mut _,
                counter.as_mut_bytes().len(),
            );
        }

        for _ in 0..counter {
            let _ = boxed.unlock();
        }

        for _ in 0..counter {
            boxed.lock()
        }
    }

    // A Boxed can be sent across threads; borrow state travels with it.
    #[test]
    fn test_threading() {
        extern crate std;

        use std::{sync::mpsc, thread};

        let (tx, rx) = mpsc::channel();

        let ch = thread::spawn(move || {
            let boxed = Boxed::<u64>::random(1);
            let val = boxed.unlock().as_slice().to_vec();

            tx.send((boxed, val)).expect("failed to send via channel");
        });

        let (boxed, val) = rx.recv().expect("failed to read from channel");

        // Still unlocked read-only: the sender never called lock().
        assert_eq!(Prot::ReadOnly, boxed.prot.get());
        assert_eq!(val, boxed.as_slice());

        ch.join().expect("child thread terminated.");
        boxed.lock();
    }

    // One borrow past u8::MAX must panic rather than wrap.
    #[test]
    #[should_panic(expected = "Retained too many times")]
    fn test_overflow_refs() {
        let boxed = Boxed::<[u8; 4]>::zero(4);

        for _ in 0..=u8::max_value() {
            let _ = boxed.unlock();
        }

        for _ in 0..boxed.refs.get() {
            boxed.lock()
        }
    }

    // A corrupted state (refs > 0 while NoAccess) is detected on retain.
    #[test]
    #[should_panic(expected = "Out-of-order retain/release detected")]
    fn test_out_of_order() {
        let boxed = Boxed::<u8>::zero(3);

        boxed.refs.set(boxed.refs.get().wrapping_sub(1));
        boxed.prot.set(Prot::NoAccess);

        boxed.retain(Prot::ReadOnly);
    }

    // Dereferencing a zero-length allocation must panic.
    #[test]
    #[should_panic(expected = "Attempted to dereference a zero-length pointer")]
    fn test_zero_length() {
        let boxed = Boxed::<u8>::zero(0);

        let _ = boxed.as_ref();
    }

    // Only one writable borrow may exist at a time.
    #[test]
    #[should_panic(expected = "Cannot unlock mutably more than once")]
    fn test_multiple_writers() {
        let mut boxed = Boxed::<u64>::zero(1);

        let _ = boxed.unlock_mut();
        let _ = boxed.unlock_mut();
    }

    // Locking without a prior unlock must panic.
    #[test]
    #[should_panic(expected = "Releases exceeded retains")]
    fn test_release_vs_retain() {
        Boxed::<u64>::zero(2).lock();
    }
}
549        Boxed::<u64>::zero(2).lock();
550    }
551}