t_rust_less_lib/memguard/
bytes.rs

1use super::alloc;
2use super::memory;
3use serde::{Deserialize, Serialize};
4use std::convert::{AsMut, AsRef};
5use std::fmt;
6use std::io;
7use std::ops::{Deref, DerefMut};
8use std::ptr::{copy_nonoverlapping, NonNull};
9use std::slice;
10use std::sync::atomic::{AtomicIsize, Ordering};
11use zeroize::Zeroize;
12
13use crate::memguard::ZeroizeBytesBuffer;
14use byteorder::WriteBytesExt;
15use rand::{CryptoRng, RngCore};
16
/// Byte buffer for sensitive data kept in strictly access-protected memory.
///
/// While no guard is alive the backing allocation is protected with `Prot::NoAccess`;
/// [`SecretBytes::borrow`] and [`SecretBytes::borrow_mut`] temporarily switch it to read-only or
/// read-write access for the lifetime of the returned guard.
///
/// This implementation borrows a lot of code and ideas from:
/// * <https://crates.io/crates/memsec>
/// * <https://crates.io/crates/secrets>
/// * <https://download.libsodium.org/doc/memory_management>
///
/// `secrets` is not good enough because it relies on libsodium, which breaks the desired
/// portability of this library (at least at the time of this writing).
///
/// `memsec` is not good enough because it focuses on protecting a generic type `T` whose size is
/// known at compile time. This library deals with dynamic amounts of sensitive data, and there is
/// no point in securing a `Vec<u8>` via `memsec`: all that would achieve is protecting the
/// pointer to the sensitive data, while the data itself stays in unsecured memory.
pub struct SecretBytes {
  /// Start of the guarded allocation (obtained from `alloc::malloc`).
  ptr: NonNull<u8>,
  /// Number of bytes currently in use (`len()`).
  size: usize,
  /// Number of bytes allocated at `ptr` (`capacity()`).
  capacity: usize,
  /// Guard counter: `0` = unlocked, `n > 0` = `n` read guards, `-1` = one write guard.
  locks: AtomicIsize,
}
39
40impl SecretBytes {
41  /// Copy from slice of bytes.
42  ///
43  /// This is not a regular From implementation because the caller has to ensure that
44  /// the original bytes are zeroed out (or are already in some secured memspace.
45  /// This different signature should be a reminder of that.
46  pub fn from_secured(bytes: &[u8]) -> Self {
47    unsafe {
48      let ptr = alloc::malloc(bytes.len());
49
50      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
51      alloc::mprotect(ptr, alloc::Prot::NoAccess);
52
53      SecretBytes {
54        ptr,
55        size: bytes.len(),
56        capacity: bytes.len(),
57        locks: AtomicIsize::new(0),
58      }
59    }
60  }
61
62  pub fn with_capacity(capacity: usize) -> SecretBytes {
63    unsafe {
64      let ptr = alloc::malloc(capacity);
65
66      alloc::mprotect(ptr, alloc::Prot::NoAccess);
67
68      SecretBytes {
69        ptr,
70        size: 0,
71        capacity,
72        locks: AtomicIsize::new(0),
73      }
74    }
75  }
76
77  pub fn with_capacity_for_chars(capacity_for_chars: usize) -> SecretBytes {
78    // UTF-8 chars may be 4 bytes long
79    Self::with_capacity(capacity_for_chars * 4)
80  }
81
82  pub fn zeroed(size: usize) -> SecretBytes {
83    unsafe {
84      let ptr = alloc::malloc(size);
85
86      memory::memzero(ptr.as_ptr(), size);
87      alloc::mprotect(ptr, alloc::Prot::NoAccess);
88
89      SecretBytes {
90        ptr,
91        size,
92        capacity: size,
93        locks: AtomicIsize::new(0),
94      }
95    }
96  }
97
98  pub fn random<T>(rng: &mut T, size: usize) -> SecretBytes
99  where
100    T: RngCore + CryptoRng,
101  {
102    unsafe {
103      let ptr = alloc::malloc(size);
104
105      rng.fill_bytes(slice::from_raw_parts_mut(ptr.as_ptr(), size));
106      alloc::mprotect(ptr, alloc::Prot::NoAccess);
107
108      SecretBytes {
109        ptr,
110        size,
111        capacity: size,
112        locks: AtomicIsize::new(0),
113      }
114    }
115  }
116
117  pub fn is_empty(&self) -> bool {
118    self.size == 0
119  }
120
121  pub fn len(&self) -> usize {
122    self.size
123  }
124
125  pub fn capacity(&self) -> usize {
126    self.capacity
127  }
128
129  pub fn borrow(&self) -> Ref {
130    self.lock_read();
131    Ref { bytes: self }
132  }
133
134  pub fn borrow_mut(&mut self) -> RefMut {
135    self.lock_write();
136    RefMut { bytes: self }
137  }
138
139  pub fn locks(&self) -> isize {
140    self.locks.load(Ordering::Relaxed)
141  }
142
143  fn lock_read(&self) {
144    let locks = self.locks.fetch_add(1, Ordering::Relaxed);
145
146    assert!(locks >= 0);
147
148    if locks == 0 {
149      unsafe {
150        alloc::mprotect(self.ptr, alloc::Prot::ReadOnly);
151      }
152    }
153  }
154
155  fn unlock_read(&self) {
156    let locks = self.locks.fetch_sub(1, Ordering::Relaxed);
157
158    assert!(locks > 0);
159
160    if locks == 1 {
161      unsafe {
162        alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
163      }
164    }
165  }
166
167  fn lock_write(&mut self) {
168    let locks = self.locks.fetch_sub(1, Ordering::Relaxed);
169
170    assert!(locks == 0);
171
172    unsafe {
173      alloc::mprotect(self.ptr, alloc::Prot::ReadWrite);
174    }
175  }
176
177  fn unlock_write(&mut self) {
178    let locks = self.locks.fetch_add(1, Ordering::Relaxed);
179
180    assert!(locks == -1);
181
182    unsafe {
183      alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
184    }
185  }
186}
187
// SAFETY(review): `SecretBytes` exclusively owns the allocation behind its private `ptr`; access
// is only handed out through `Ref`/`RefMut` guards borrowing the value, so moving the owner to
// another thread does not create aliasing.
unsafe impl Send for SecretBytes {}

// NOTE(review): `Sync` permits concurrent `borrow()` calls from several threads. `lock_read` and
// `unlock_read` update the counter with `Ordering::Relaxed` and call `mprotect` as a separate
// step, so a reader can observe `locks > 0` before the memory has actually been made readable,
// and a departing reader's `NoAccess` can race with a newly admitted reader. Confirm that
// cross-thread sharing is really intended before relying on this impl.
unsafe impl Sync for SecretBytes {}
191
192impl Zeroize for SecretBytes {
193  fn zeroize(&mut self) {
194    self.lock_write();
195    unsafe {
196      memory::memzero(self.ptr.as_ptr(), self.capacity);
197    }
198    self.unlock_write();
199  }
200}
201
202impl Drop for SecretBytes {
203  fn drop(&mut self) {
204    unsafe { alloc::free(self.ptr) }
205  }
206}
207
208impl Clone for SecretBytes {
209  fn clone(&self) -> Self {
210    unsafe {
211      let ptr = alloc::malloc(self.capacity);
212
213      copy_nonoverlapping(self.borrow().as_ref().as_ptr(), ptr.as_ptr(), self.capacity);
214      alloc::mprotect(ptr, alloc::Prot::NoAccess);
215
216      SecretBytes {
217        ptr,
218        size: self.size,
219        capacity: self.capacity,
220        locks: AtomicIsize::new(0),
221      }
222    }
223  }
224}
225
226impl PartialEq for SecretBytes {
227  fn eq(&self, other: &Self) -> bool {
228    self.borrow().as_bytes() == other.borrow().as_bytes()
229  }
230}
231
232impl Eq for SecretBytes {}
233
234impl From<&mut [u8]> for SecretBytes {
235  fn from(bytes: &mut [u8]) -> Self {
236    unsafe {
237      let ptr = alloc::malloc(bytes.len());
238
239      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
240      memory::memzero(bytes.as_mut_ptr(), bytes.len());
241      alloc::mprotect(ptr, alloc::Prot::NoAccess);
242
243      SecretBytes {
244        ptr,
245        size: bytes.len(),
246        capacity: bytes.len(),
247        locks: AtomicIsize::new(0),
248      }
249    }
250  }
251}
252
253impl From<Vec<u8>> for SecretBytes {
254  fn from(mut bytes: Vec<u8>) -> Self {
255    unsafe {
256      let ptr = alloc::malloc(bytes.len());
257
258      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
259      memory::memzero(bytes.as_mut_ptr(), bytes.len());
260      alloc::mprotect(ptr, alloc::Prot::NoAccess);
261
262      SecretBytes {
263        ptr,
264        size: bytes.len(),
265        capacity: bytes.len(),
266        locks: AtomicIsize::new(0),
267      }
268    }
269  }
270}
271
272impl From<String> for SecretBytes {
273  fn from(mut str: String) -> Self {
274    unsafe {
275      let bytes = str.as_bytes_mut();
276      let ptr = alloc::malloc(bytes.len());
277
278      copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
279      memory::memzero(bytes.as_mut_ptr(), bytes.len());
280      alloc::mprotect(ptr, alloc::Prot::NoAccess);
281
282      SecretBytes {
283        ptr,
284        size: bytes.len(),
285        capacity: bytes.len(),
286        locks: AtomicIsize::new(0),
287      }
288    }
289  }
290}
291
292// Note: This has to be used with care as it is not clear how many temporary buffers
293// the serializer uses or cleans them up correctly.
294// Some examples that are (mostly) safe to use:
295//   serde_json::ser::Serializer writes the bytes as numerical array directly to the output writer
296//   rmp_serde::encode::Serializer write the bytes directly to the output writer
297impl Serialize for SecretBytes {
298  fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
299  where
300    S: serde::Serializer,
301  {
302    serializer.serialize_bytes(self.borrow().as_bytes())
303  }
304}
305
306impl<'de> Deserialize<'de> for SecretBytes {
307  fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
308  where
309    D: serde::Deserializer<'de>,
310  {
311    deserializer.deserialize_bytes(SafeBytesVisitor())
312  }
313}
314
315struct SafeBytesVisitor();
316
317impl<'de> serde::de::Visitor<'de> for SafeBytesVisitor {
318  type Value = SecretBytes;
319
320  fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
321    formatter.write_str("a byte array")
322  }
323
324  fn visit_borrowed_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
325  where
326    E: serde::de::Error,
327  {
328    Ok(SecretBytes::from_secured(v))
329  }
330
331  fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
332  where
333    A: serde::de::SeqAccess<'de>,
334  {
335    let mut buf = ZeroizeBytesBuffer::with_capacity(seq.size_hint().unwrap_or(1024));
336
337    while let Some(value) = seq.next_element::<u8>()? {
338      buf.write_u8(value).ok();
339    }
340
341    Ok(SecretBytes::from_secured(&buf))
342  }
343}
344
345pub struct Ref<'a> {
346  bytes: &'a SecretBytes,
347}
348
349impl<'a> Ref<'a> {
350  pub fn as_bytes(&self) -> &[u8] {
351    unsafe { slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size) }
352  }
353
354  pub fn as_str(&self) -> &str {
355    unsafe {
356      let bytes = slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size);
357      std::str::from_utf8_unchecked(bytes)
358    }
359  }
360}
361
362impl<'a> Drop for Ref<'a> {
363  fn drop(&mut self) {
364    self.bytes.unlock_read()
365  }
366}
367
368impl<'a> Deref for Ref<'a> {
369  type Target = [u8];
370
371  fn deref(&self) -> &Self::Target {
372    self.as_bytes()
373  }
374}
375
376impl<'a> AsRef<[u8]> for Ref<'a> {
377  fn as_ref(&self) -> &[u8] {
378    self.as_bytes()
379  }
380}
381
382pub struct RefMut<'a> {
383  bytes: &'a mut SecretBytes,
384}
385
386impl<'a> RefMut<'a> {
387  pub fn clear(&mut self) {
388    unsafe {
389      memory::memzero(self.bytes.ptr.as_ptr(), self.bytes.capacity);
390      self.bytes.size = 0;
391    }
392  }
393
394  pub fn append_char(&mut self, ch: char) {
395    let ch_len = ch.len_utf8();
396
397    assert!(ch_len + self.bytes.size <= self.bytes.capacity);
398
399    unsafe {
400      let bytes_with_extra = slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size + ch_len);
401      ch.encode_utf8(&mut bytes_with_extra[self.bytes.size..]);
402    }
403    self.bytes.size += ch_len;
404  }
405
406  pub fn remove_char(&mut self) {
407    unsafe {
408      let bytes = slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size);
409      let tail_len = match std::str::from_utf8_unchecked(bytes).chars().last() {
410        Some(ch) => ch.len_utf8(),
411        None => return,
412      };
413      assert!(tail_len <= self.bytes.size);
414      for b in &mut bytes[self.bytes.size - tail_len..] {
415        *b = 0
416      }
417
418      self.bytes.size -= tail_len;
419    }
420  }
421}
422
423impl<'a> Drop for RefMut<'a> {
424  fn drop(&mut self) {
425    self.bytes.unlock_write()
426  }
427}
428
429impl<'a> Deref for RefMut<'a> {
430  type Target = [u8];
431
432  fn deref(&self) -> &Self::Target {
433    unsafe { slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size) }
434  }
435}
436
437impl<'a> DerefMut for RefMut<'a> {
438  fn deref_mut(&mut self) -> &mut Self::Target {
439    unsafe { slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size) }
440  }
441}
442
443impl<'a> AsRef<[u8]> for RefMut<'a> {
444  fn as_ref(&self) -> &[u8] {
445    unsafe { slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size) }
446  }
447}
448
449impl<'a> AsMut<[u8]> for RefMut<'a> {
450  fn as_mut(&mut self) -> &mut [u8] {
451    unsafe { slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size) }
452  }
453}
454
455impl<'a> io::Write for RefMut<'a> {
456  fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
457    let available = self.bytes.capacity() - self.bytes.size;
458
459    if available == 0 {
460      return Err(io::ErrorKind::WriteZero.into());
461    }
462    let transfer = available.min(buf.len());
463
464    unsafe {
465      copy_nonoverlapping(buf.as_ptr(), self.bytes.ptr.as_ptr().add(self.bytes.size), transfer);
466    }
467    self.bytes.size += transfer;
468
469    Ok(transfer)
470  }
471
472  fn flush(&mut self) -> io::Result<()> {
473    Ok(())
474  }
475}
476
477impl std::fmt::Debug for SecretBytes {
478  fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
479    write!(f, "<Secret>")
480  }
481}
482
#[cfg(test)]
mod tests {
  use rand::{distributions, thread_rng, Rng};
  use spectral::prelude::*;

  use super::*;
  use crate::memguard::ZeroizeBytesBuffer;

  /// Asserts that two byte slices are identical.
  fn assert_slices_equal(actual: &[u8], expected: &[u8]) {
    assert!(actual == expected)
  }

  /// Produces `count` random bytes, none of them zero, so that a later wipe of the source is
  /// distinguishable from its original content.
  fn non_zero_random_bytes(count: usize) -> Vec<u8> {
    thread_rng()
      .sample_iter(&distributions::Standard)
      .filter(|b: &u8| *b != 0)
      .take(count)
      .collect()
  }

  #[test]
  fn test_borrow_read_only() {
    let mut plain = non_zero_random_bytes(200);
    let expected = plain.clone();

    for b in &plain {
      assert_that(b).is_not_equal_to(0);
    }

    let guarded = SecretBytes::from(plain.as_mut_slice());

    assert_that(&guarded.len()).is_equal_to(plain.len());
    assert_that(&guarded.borrow().as_ref().len()).is_equal_to(plain.len());

    // Converting from `&mut [u8]` must wipe the source.
    for b in &plain {
      assert_that(b).is_equal_to(0);
    }

    assert_that(&guarded.locks()).is_equal_to(0);
    assert_slices_equal(&guarded.borrow(), &expected);
    assert_that(&guarded.locks()).is_equal_to(0);

    {
      let readers = [guarded.borrow(), guarded.borrow(), guarded.borrow()];

      assert_that(&readers[0].len()).is_equal_to(200);
      assert_that(&guarded.locks()).is_equal_to(3);
      for reader in readers.iter() {
        assert_slices_equal(reader, &expected);
      }
    }
    assert_that(&guarded.locks()).is_equal_to(0);
  }

  #[test]
  fn test_zeroed() {
    let guarded = SecretBytes::zeroed(200);

    assert_that(&guarded.len()).is_equal_to(200);
    assert_that(&guarded.capacity()).is_equal_to(200);

    let view = guarded.borrow();
    assert_that(&view.len()).is_equal_to(200);
    for b in view.as_ref() {
      assert_that(b).is_equal_to(0);
    }
  }

  #[test]
  fn test_borrow_read_write() {
    let mut first = non_zero_random_bytes(200);
    let second = non_zero_random_bytes(200);
    let expected_first = first.clone();
    let expected_second = second.clone();

    for b in &first {
      assert_that(b).is_not_equal_to(0);
    }

    let mut guarded = SecretBytes::from(first.as_mut_slice());

    for b in &first {
      assert_that(b).is_equal_to(0);
    }

    assert_that(&guarded.locks()).is_equal_to(0);
    assert_slices_equal(&guarded.borrow(), &expected_first);

    guarded.borrow_mut().as_mut().copy_from_slice(&second);

    assert_that(&guarded.locks()).is_equal_to(0);
    assert_slices_equal(&guarded.borrow(), &expected_second);
  }

  #[test]
  fn test_init_strong_random() {
    let mut rng = thread_rng();
    let secret = SecretBytes::random(&mut rng, 32);

    assert_that(&secret.len()).is_equal_to(32);
    assert_that(&secret.borrow().as_ref().len()).is_equal_to(32);
  }

  #[test]
  fn test_str_like_ops() {
    let mut secret = SecretBytes::with_capacity_for_chars(20);

    assert_that(&secret.len()).is_equal_to(0);
    assert_that(&secret.capacity()).is_equal_to(80);

    // (char to append, expected byte length afterwards)
    let appends: [(char, usize); 4] = [('a', 1), ('ä', 3), ('€', 6), ('ß', 8)];
    for (index, &(ch, expected_len)) in appends.iter().enumerate() {
      secret.borrow_mut().append_char(ch);
      assert_that(&secret.len()).is_equal_to(expected_len);
      assert_that(&secret.borrow().as_str().chars().count()).is_equal_to(index + 1);
    }
    assert_that(&secret.borrow().as_str()).is_equal_to("aä€ß");

    // (expected byte length, expected text) after each removal; the last removal is a no-op.
    let removals: [(usize, &str); 5] = [(6, "aä€"), (3, "aä"), (1, "a"), (0, ""), (0, "")];
    for &(expected_len, expected_text) in removals.iter() {
      secret.borrow_mut().remove_char();
      assert_that(&secret.len()).is_equal_to(expected_len);
      assert_that(&secret.borrow().as_str().chars().count())
        .is_equal_to(expected_text.chars().count());
      assert_that(&secret.borrow().as_str()).is_equal_to(expected_text);
    }
  }

  #[test]
  fn test_serde_json() {
    let mut rng = thread_rng();
    let original = SecretBytes::random(&mut rng, 32);
    let mut encoded = ZeroizeBytesBuffer::with_capacity(1024);

    serde_json::to_writer(&mut encoded, &original).unwrap();

    let decoded: SecretBytes = serde_json::from_reader(encoded.as_ref()).unwrap();

    assert_that(&decoded).is_equal_to(&original);
  }

  #[test]
  fn test_serde_rmb() {
    let mut rng = thread_rng();
    let original = SecretBytes::random(&mut rng, 32);
    let mut encoded = ZeroizeBytesBuffer::with_capacity(1024);

    rmp_serde::encode::write_named(&mut encoded, &original).unwrap();

    let decoded: SecretBytes = rmp_serde::from_read_ref(&encoded).unwrap();

    assert_that(&decoded).is_equal_to(&original);
  }
}