t_rust_less_lib/memguard/words.rs

use super::alloc;
use super::memory;
use capnp::message::{AllocationStrategy, Allocator, SUGGESTED_ALLOCATION_STRATEGY, SUGGESTED_FIRST_SEGMENT_WORDS};
use capnp::Word;
use log::warn;
use std::convert::{AsMut, AsRef};
use std::ops::{Deref, DerefMut};
use std::ptr::{copy_nonoverlapping, NonNull};
use std::slice;
use std::sync::atomic::{AtomicIsize, Ordering};

/// Strictly memory-protected words containing sensitive data.
///
/// This implementation borrows a lot of code and ideas from:
/// * https://crates.io/crates/memsec
/// * https://crates.io/crates/secrets
/// * https://download.libsodium.org/doc/memory_management
///
/// `secrets` is not good enough because it relies on libsodium, which breaks the desired
/// portability of this library (at least at the time of this writing).
///
/// `memsec` is not good enough because it focuses on protecting a generic type `T` whose size
/// is known at compile time. In this library we are dealing with dynamic amounts of sensitive
/// data, and there is no point in securing a `Vec<u8>` via `memsec` ... all we would achieve
/// is protecting the pointer to the sensitive data in unsecured space.
///
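/// # Example
///
/// A minimal usage sketch (the values are illustrative):
///
/// ```ignore
/// let secret = SecretWords::zeroed(4);
///
/// assert_eq!(secret.len(), 4);
/// assert_eq!(secret.locks(), 0);
/// {
///   // Borrowing unprotects the memory; dropping the guard re-protects it.
///   let words = secret.borrow();
///   assert!(words.as_bytes().iter().all(|b| *b == 0));
/// }
/// assert_eq!(secret.locks(), 0);
/// ```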
pub struct SecretWords {
  ptr: NonNull<Word>,   // start of the guarded allocation (NoAccess while not borrowed)
  size: usize,          // number of words currently in use
  capacity: usize,      // number of words allocated
  locks: AtomicIsize,   // borrow counter, see the lock protocol below
}

impl SecretWords {
  /// Copy from a slice of bytes.
  ///
  /// This is not a regular `From` implementation because the caller has to ensure that
  /// the original bytes are zeroed out (or are already in some secured memspace).
  /// This different signature should be a reminder of that.
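  ///
  /// # Example
  ///
  /// A sketch of the caller's side of the contract (illustrative values):
  ///
  /// ```ignore
  /// let mut raw = [0x42u8; 16];
  /// let secret = SecretWords::from_secured(&raw[..]);
  ///
  /// // The caller remains responsible for wiping the source.
  /// raw.iter_mut().for_each(|b| *b = 0);
  /// ```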
  pub fn from_secured(bytes: &[u8]) -> Self {
    if bytes.len() % 8 != 0 {
      warn!("Bytes not aligned to 8 bytes. Probably these are not the bytes you are looking for.");
    }
    unsafe {
      let len = bytes.len() / 8;
      let ptr = alloc::malloc(len * 8).cast();

      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr() as *mut u8, len * 8);
      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size: len,
        capacity: len,
        locks: AtomicIsize::new(0),
      }
    }
  }

  /// Create an empty instance with the given capacity (in words).
  pub fn with_capacity(capacity: usize) -> SecretWords {
    unsafe {
      let ptr = alloc::malloc(capacity * 8).cast();

      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size: 0,
        capacity,
        locks: AtomicIsize::new(0),
      }
    }
  }

  /// Create a zero-filled instance of the given size (in words).
  pub fn zeroed(size: usize) -> SecretWords {
    unsafe {
      let ptr = alloc::malloc(size * 8).cast();

      memory::memzero(ptr.as_ptr() as *mut u8, size * 8);
      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size,
        capacity: size,
        locks: AtomicIsize::new(0),
      }
    }
  }

  pub fn is_empty(&self) -> bool {
    self.size == 0
  }

  pub fn len(&self) -> usize {
    self.size
  }

  pub fn capacity(&self) -> usize {
    self.capacity
  }

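  /// Borrow an immutable view of the words.
  ///
  /// The underlying memory is made readable while the returned guard is alive and is
  /// protected again once the last read guard is dropped.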
  pub fn borrow(&self) -> Ref<'_> {
    self.lock_read();
    Ref { words: self }
  }

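  /// Borrow an exclusive, mutable view of the words.
  ///
  /// A sketch (illustrative): the memory is writable only while the guard is alive.
  ///
  /// ```ignore
  /// let mut secret = SecretWords::zeroed(1);
  /// secret.borrow_mut().as_mut()[0] = 42;
  /// assert_eq!(secret.borrow().as_bytes()[0], 42);
  /// ```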
  pub fn borrow_mut(&mut self) -> RefMut<'_> {
    self.lock_write();
    RefMut { words: self }
  }

  pub fn locks(&self) -> isize {
    self.locks.load(Ordering::Relaxed)
  }

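  // Lock protocol: `locks` counts active borrows. A positive value means that many
  // concurrent read borrows (memory is ReadOnly), -1 means a single write borrow
  // (memory is ReadWrite) and 0 means no borrows at all (memory is NoAccess).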
  fn lock_read(&self) {
    let locks = self.locks.fetch_add(1, Ordering::Relaxed);

    assert!(locks >= 0);

    if locks == 0 {
      // First reader: make the memory readable.
      unsafe {
        alloc::mprotect(self.ptr, alloc::Prot::ReadOnly);
      }
    }
  }

  fn unlock_read(&self) {
    let locks = self.locks.fetch_sub(1, Ordering::Relaxed);

    assert!(locks > 0);

    if locks == 1 {
      // Last reader is gone: protect the memory again.
      unsafe {
        alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
      }
    }
  }

  fn lock_write(&mut self) {
    let locks = self.locks.fetch_sub(1, Ordering::Relaxed);

    assert!(locks == 0);

    unsafe {
      alloc::mprotect(self.ptr, alloc::Prot::ReadWrite);
    }
  }

  fn unlock_write(&mut self) {
    let locks = self.locks.fetch_add(1, Ordering::Relaxed);

    assert!(locks == -1);

    unsafe {
      alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
    }
  }

  /// Internal use only.
  /// This will take a write-lock and never release it until the `SecretWords` are dropped.
  fn as_mut_ptr(&mut self) -> *mut Word {
    self.lock_write();

    self.ptr.as_ptr()
  }
}

unsafe impl Send for SecretWords {}

unsafe impl Sync for SecretWords {}

impl Drop for SecretWords {
  fn drop(&mut self) {
    unsafe { alloc::free(self.ptr) }
  }
}

impl Clone for SecretWords {
  fn clone(&self) -> Self {
    unsafe {
      let ptr = alloc::malloc(self.capacity * 8).cast::<Word>();

      copy_nonoverlapping(self.borrow().as_words().as_ptr(), ptr.as_ptr(), self.capacity);
      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size: self.size,
        capacity: self.capacity,
        locks: AtomicIsize::new(0),
      }
    }
  }
}

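/// Converts from a mutable byte slice, wiping the source afterwards.
///
/// A sketch (illustrative values):
///
/// ```ignore
/// let mut raw = [1u8; 16];
/// let secret = SecretWords::from(&mut raw[..]);
///
/// assert_eq!(secret.len(), 2); // 16 bytes = 2 words
/// assert!(raw.iter().all(|b| *b == 0)); // the source has been zeroed
/// ```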
impl From<&mut [u8]> for SecretWords {
  fn from(bytes: &mut [u8]) -> Self {
    if bytes.len() % 8 != 0 {
      warn!("Bytes not aligned to 8 bytes. Probably these are not the bytes you are looking for.");
    }
    unsafe {
      let len = bytes.len() / 8;
      let ptr = alloc::malloc(len * 8).cast();

      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr() as *mut u8, len * 8);
      // Wipe the original buffer so the data only lives in guarded memory.
      memory::memzero(bytes.as_mut_ptr(), bytes.len());
      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size: len,
        capacity: len,
        locks: AtomicIsize::new(0),
      }
    }
  }
}

impl From<&mut [Word]> for SecretWords {
  fn from(words: &mut [Word]) -> Self {
    unsafe {
      let ptr = alloc::malloc(words.len() * 8).cast();

      copy_nonoverlapping(words.as_ptr(), ptr.as_ptr(), words.len());
      memory::memzero(words.as_mut_ptr() as *mut u8, words.len() * 8);
      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size: words.len(),
        capacity: words.len(),
        locks: AtomicIsize::new(0),
      }
    }
  }
}

impl From<Vec<u8>> for SecretWords {
  fn from(mut bytes: Vec<u8>) -> Self {
    if bytes.len() % 8 != 0 {
      warn!("Bytes not aligned to 8 bytes. Probably these are not the bytes you are looking for.");
    }
    unsafe {
      let len = bytes.len() / 8;
      let ptr = alloc::malloc(len * 8).cast();

      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr() as *mut u8, len * 8);
      memory::memzero(bytes.as_mut_ptr(), bytes.len());
      alloc::mprotect(ptr, alloc::Prot::NoAccess);

      SecretWords {
        ptr,
        size: len,
        capacity: len,
        locks: AtomicIsize::new(0),
      }
    }
  }
}

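/// Immutable view on the contents of a `SecretWords`, holding a read lock until dropped.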
pub struct Ref<'a> {
  words: &'a SecretWords,
}

impl Ref<'_> {
  pub fn as_bytes(&self) -> &[u8] {
    unsafe {
      let words = slice::from_raw_parts(self.words.ptr.as_ptr(), self.words.size);
      Word::words_to_bytes(words)
    }
  }

  pub fn as_words(&self) -> &[Word] {
    unsafe { slice::from_raw_parts(self.words.ptr.as_ptr(), self.words.size) }
  }
}

impl Drop for Ref<'_> {
  fn drop(&mut self) {
    self.words.unlock_read()
  }
}

impl Deref for Ref<'_> {
  type Target = [u8];

  fn deref(&self) -> &Self::Target {
    unsafe { slice::from_raw_parts(self.words.ptr.as_ptr() as *const u8, self.words.size * 8) }
  }
}

impl AsRef<[u8]> for Ref<'_> {
  fn as_ref(&self) -> &[u8] {
    self.as_bytes()
  }
}

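/// Mutable view on the contents of a `SecretWords`, holding the write lock until dropped.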
pub struct RefMut<'a> {
  words: &'a mut SecretWords,
}

impl Drop for RefMut<'_> {
  fn drop(&mut self) {
    self.words.unlock_write()
  }
}

impl Deref for RefMut<'_> {
  type Target = [u8];

  fn deref(&self) -> &Self::Target {
    unsafe { slice::from_raw_parts(self.words.ptr.as_ptr() as *const u8, self.words.size * 8) }
  }
}

impl DerefMut for RefMut<'_> {
  fn deref_mut(&mut self) -> &mut Self::Target {
    unsafe { slice::from_raw_parts_mut(self.words.ptr.as_ptr() as *mut u8, self.words.size * 8) }
  }
}

impl AsRef<[u8]> for RefMut<'_> {
  fn as_ref(&self) -> &[u8] {
    unsafe { slice::from_raw_parts(self.words.ptr.as_ptr() as *const u8, self.words.size * 8) }
  }
}

impl AsMut<[u8]> for RefMut<'_> {
  fn as_mut(&mut self) -> &mut [u8] {
    unsafe { slice::from_raw_parts_mut(self.words.ptr.as_ptr() as *mut u8, self.words.size * 8) }
  }
}

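/// Cap'n Proto message allocator that keeps all segments in guarded memory.
///
/// A minimal usage sketch (hedged: the message root type is omitted, since it depends
/// on a generated schema that is not part of this file):
///
/// ```ignore
/// let mut message = capnp::message::Builder::new(SecureHHeapAllocator::default());
/// // Every segment the builder requests is now backed by a `SecretWords`.
/// ```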
pub struct SecureHHeapAllocator {
  owned_memory: Vec<SecretWords>,
  next_size: u32,
  allocation_strategy: AllocationStrategy,
}

unsafe impl Allocator for SecureHHeapAllocator {
  fn allocate_segment(&mut self, minimum_size: u32) -> (*mut u8, u32) {
    let size = ::std::cmp::max(minimum_size, self.next_size);
    let mut new_words = SecretWords::zeroed(size as usize);
    // Keep the segment writable for the message builder; protection is only
    // restored when the owning `SecretWords` is dropped.
    let ptr = new_words.as_mut_ptr() as *mut u8;
    self.owned_memory.push(new_words);

    if let AllocationStrategy::GrowHeuristically = self.allocation_strategy {
      self.next_size += size;
    }
    (ptr, size)
  }

  unsafe fn deallocate_segment(&mut self, _ptr: *mut u8, _word_size: u32, _words_used: u32) {
    self.next_size = SUGGESTED_FIRST_SEGMENT_WORDS;
  }
}

impl Default for SecureHHeapAllocator {
  fn default() -> Self {
    SecureHHeapAllocator {
      owned_memory: Vec::new(),
      next_size: SUGGESTED_FIRST_SEGMENT_WORDS,
      allocation_strategy: SUGGESTED_ALLOCATION_STRATEGY,
    }
  }
}

#[cfg(test)]
mod tests {
  use byteorder::{BigEndian, ByteOrder};
  use rand::{distributions, thread_rng, Rng};
  use spectral::prelude::*;
  use std::iter;

  use super::*;

  fn assert_slices_equal(actual: &[u8], expected: &[u8]) {
    assert!(actual == expected)
  }

  fn word_from_u64(w: u64) -> Word {
    capnp::word(
      (w & 0xff) as u8,
      ((w >> 8) & 0xff) as u8,
      ((w >> 16) & 0xff) as u8,
      ((w >> 24) & 0xff) as u8,
      ((w >> 32) & 0xff) as u8,
      ((w >> 40) & 0xff) as u8,
      ((w >> 48) & 0xff) as u8,
      ((w >> 56) & 0xff) as u8,
    )
  }

  #[test]
  fn test_borrow_read_only() {
    let rng = thread_rng();
    let mut source = rng
      .sample_iter::<u8, _>(&distributions::Standard)
      .filter(|w| *w != 0)
      .take(200 * 8)
      .collect::<Vec<u8>>();
    let expected = source.clone();

    for w in source.iter() {
      assert_that(&w).is_not_equal_to(&0);
    }

    let guarded = SecretWords::from(source.as_mut_slice());

    assert_that(&guarded.len()).is_equal_to(source.len() / 8);
    assert_that(&guarded.borrow().as_words().len()).is_equal_to(source.len() / 8);

    for w in source.iter() {
      assert_that(&w).is_equal_to(&0);
    }

    assert_that(&guarded.locks()).is_equal_to(0);
    assert_slices_equal(&guarded.borrow(), &expected);
    assert_that(&guarded.locks()).is_equal_to(0);

    {
      let ref1 = guarded.borrow();
      let ref2 = guarded.borrow();
      let ref3 = guarded.borrow();

      assert_that(&ref1.len()).is_equal_to(200 * 8);
      assert_that(&guarded.locks()).is_equal_to(3);
      assert_slices_equal(&ref1, &expected);
      assert_slices_equal(&ref2, &expected);
      assert_slices_equal(&ref3, &expected);
    }
    assert_that(&guarded.locks()).is_equal_to(0);
  }

  #[test]
  fn test_zeroed() {
    let guarded = SecretWords::zeroed(200);

    assert_that(&guarded.len()).is_equal_to(200);
    assert_that(&guarded.capacity()).is_equal_to(200);

    {
      let ref1 = guarded.borrow();

      assert_that(&ref1.len()).is_equal_to(200 * 8);
      for w in ref1.as_words() {
        assert_that(&w).is_equal_to(&word_from_u64(0));
      }
    }
  }

  #[test]
  fn test_borrow_read_write() {
    let mut rng = thread_rng();
    let mut source = iter::repeat(())
      .map(|_| rng.sample(distributions::Standard))
      .filter(|w| *w != 0)
      .take(200 * 8)
      .collect::<Vec<u8>>();
    let source2 = iter::repeat(())
      .map(|_| rng.sample(distributions::Standard))
      .filter(|w| *w != 0)
      .take(200 * 8)
      .collect::<Vec<u8>>();
    let expected = source.clone();
    let expected2 = source2.clone();

    for w in source.iter() {
      assert_that(&w).is_not_equal_to(&0);
    }

    let mut guarded = SecretWords::from(source.as_mut_slice());

    for w in source.iter() {
      assert_that(&w).is_equal_to(&0);
    }

    assert_that(&guarded.locks()).is_equal_to(0);
    assert_slices_equal(&guarded.borrow(), &expected);

    guarded.borrow_mut().as_mut().copy_from_slice(&source2);

    assert_that(&guarded.locks()).is_equal_to(0);
    assert_slices_equal(&guarded.borrow(), &expected2);
  }

  #[test]
  fn test_from_unaligned_source() {
    let mut chunks = [0u8; 16];

    BigEndian::write_u64(&mut chunks[0..8], 0x1234_5678_1234_5678);
    BigEndian::write_u64(&mut chunks[8..16], 0xf0e1_d2c3_b4a5_9687);

    let mut bytes1 = [0u8; 100 * 16 + 1];
    let mut bytes2 = [0u8; 100 * 16 + 3];

    for i in 0..100 {
      bytes1[i * 16 + 1..i * 16 + 1 + 16].copy_from_slice(&chunks);
      bytes2[i * 16 + 3..i * 16 + 3 + 16].copy_from_slice(&chunks);
    }

    let guarded1 = SecretWords::from(&mut bytes1[1..]);
    let guarded2 = SecretWords::from(&mut bytes2[3..]);

    for b in &bytes1[..] {
      assert_that(b).is_equal_to(0);
    }
    for b in &bytes2[..] {
      assert_that(b).is_equal_to(0);
    }

    assert_that(&guarded1.len()).is_equal_to(200);
    assert_that(&guarded2.len()).is_equal_to(200);

    for (idx, w) in guarded1.borrow().chunks(8).enumerate() {
      if idx % 2 == 0 {
        assert_that(&w).is_equal_to(&[0x12u8, 0x34u8, 0x56u8, 0x78u8, 0x12u8, 0x34u8, 0x56u8, 0x78u8][..]);
      } else {
        assert_that(&w).is_equal_to(&[0xf0u8, 0xe1u8, 0xd2u8, 0xc3u8, 0xb4u8, 0xa5u8, 0x96u8, 0x87u8][..]);
      }
    }
    for (idx, w) in guarded2.borrow().chunks(8).enumerate() {
      if idx % 2 == 0 {
        assert_that(&w).is_equal_to(&[0x12u8, 0x34u8, 0x56u8, 0x78u8, 0x12u8, 0x34u8, 0x56u8, 0x78u8][..]);
      } else {
        assert_that(&w).is_equal_to(&[0xf0u8, 0xe1u8, 0xd2u8, 0xc3u8, 0xb4u8, 0xa5u8, 0x96u8, 0x87u8][..]);
      }
    }
  }
}