t_rust_less_lib/memguard/
bytes.rs1use super::alloc;
2use super::memory;
3use serde::{Deserialize, Serialize};
4use std::convert::{AsMut, AsRef};
5use std::fmt;
6use std::io;
7use std::ops::{Deref, DerefMut};
8use std::ptr::{copy_nonoverlapping, NonNull};
9use std::slice;
10use std::sync::atomic::{AtomicIsize, Ordering};
11use zeroize::Zeroize;
12
13use crate::memguard::ZeroizeBytesBuffer;
14use byteorder::WriteBytesExt;
15use rand::{CryptoRng, RngCore};
16
/// Owned buffer for secret bytes, stored in its own allocation obtained from
/// `alloc::malloc`. Outside of a `Ref`/`RefMut` guard the allocation is kept at
/// `alloc::Prot::NoAccess` (set via `alloc::mprotect`).
pub struct SecretBytes {
    // Start of the guarded allocation, `capacity` bytes long.
    ptr: NonNull<u8>,
    // Number of bytes currently in use (never more than `capacity`).
    size: usize,
    // Size of the allocation in bytes.
    capacity: usize,
    // Lock state maintained by lock_read/lock_write: > 0 is the number of live
    // read guards, -1 means a write guard is live, 0 means unlocked.
    locks: AtomicIsize,
}
39
40impl SecretBytes {
41 pub fn from_secured(bytes: &[u8]) -> Self {
47 unsafe {
48 let ptr = alloc::malloc(bytes.len());
49
50 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
51 alloc::mprotect(ptr, alloc::Prot::NoAccess);
52
53 SecretBytes {
54 ptr,
55 size: bytes.len(),
56 capacity: bytes.len(),
57 locks: AtomicIsize::new(0),
58 }
59 }
60 }
61
62 pub fn with_capacity(capacity: usize) -> SecretBytes {
63 unsafe {
64 let ptr = alloc::malloc(capacity);
65
66 alloc::mprotect(ptr, alloc::Prot::NoAccess);
67
68 SecretBytes {
69 ptr,
70 size: 0,
71 capacity,
72 locks: AtomicIsize::new(0),
73 }
74 }
75 }
76
77 pub fn with_capacity_for_chars(capacity_for_chars: usize) -> SecretBytes {
78 Self::with_capacity(capacity_for_chars * 4)
80 }
81
82 pub fn zeroed(size: usize) -> SecretBytes {
83 unsafe {
84 let ptr = alloc::malloc(size);
85
86 memory::memzero(ptr.as_ptr(), size);
87 alloc::mprotect(ptr, alloc::Prot::NoAccess);
88
89 SecretBytes {
90 ptr,
91 size,
92 capacity: size,
93 locks: AtomicIsize::new(0),
94 }
95 }
96 }
97
98 pub fn random<T>(rng: &mut T, size: usize) -> SecretBytes
99 where
100 T: RngCore + CryptoRng,
101 {
102 unsafe {
103 let ptr = alloc::malloc(size);
104
105 rng.fill_bytes(slice::from_raw_parts_mut(ptr.as_ptr(), size));
106 alloc::mprotect(ptr, alloc::Prot::NoAccess);
107
108 SecretBytes {
109 ptr,
110 size,
111 capacity: size,
112 locks: AtomicIsize::new(0),
113 }
114 }
115 }
116
117 pub fn is_empty(&self) -> bool {
118 self.size == 0
119 }
120
121 pub fn len(&self) -> usize {
122 self.size
123 }
124
125 pub fn capacity(&self) -> usize {
126 self.capacity
127 }
128
129 pub fn borrow(&self) -> Ref {
130 self.lock_read();
131 Ref { bytes: self }
132 }
133
134 pub fn borrow_mut(&mut self) -> RefMut {
135 self.lock_write();
136 RefMut { bytes: self }
137 }
138
139 pub fn locks(&self) -> isize {
140 self.locks.load(Ordering::Relaxed)
141 }
142
143 fn lock_read(&self) {
144 let locks = self.locks.fetch_add(1, Ordering::Relaxed);
145
146 assert!(locks >= 0);
147
148 if locks == 0 {
149 unsafe {
150 alloc::mprotect(self.ptr, alloc::Prot::ReadOnly);
151 }
152 }
153 }
154
155 fn unlock_read(&self) {
156 let locks = self.locks.fetch_sub(1, Ordering::Relaxed);
157
158 assert!(locks > 0);
159
160 if locks == 1 {
161 unsafe {
162 alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
163 }
164 }
165 }
166
167 fn lock_write(&mut self) {
168 let locks = self.locks.fetch_sub(1, Ordering::Relaxed);
169
170 assert!(locks == 0);
171
172 unsafe {
173 alloc::mprotect(self.ptr, alloc::Prot::ReadWrite);
174 }
175 }
176
177 fn unlock_write(&mut self) {
178 let locks = self.locks.fetch_add(1, Ordering::Relaxed);
179
180 assert!(locks == -1);
181
182 unsafe {
183 alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
184 }
185 }
186}
187
// SAFETY: `SecretBytes` exclusively owns the allocation behind `ptr`; the raw
// `NonNull` pointer is the only reason the type is not automatically `Send`.
unsafe impl Send for SecretBytes {}

// SAFETY: shared access goes through `borrow()`, which coordinates readers via
// the atomic `locks` counter. Soundness of concurrent borrows depends on
// `lock_read`/`unlock_read` keeping the counter update ordered with their
// `alloc::mprotect` calls.
unsafe impl Sync for SecretBytes {}
191
192impl Zeroize for SecretBytes {
193 fn zeroize(&mut self) {
194 self.lock_write();
195 unsafe {
196 memory::memzero(self.ptr.as_ptr(), self.capacity);
197 }
198 self.unlock_write();
199 }
200}
201
202impl Drop for SecretBytes {
203 fn drop(&mut self) {
204 unsafe { alloc::free(self.ptr) }
205 }
206}
207
208impl Clone for SecretBytes {
209 fn clone(&self) -> Self {
210 unsafe {
211 let ptr = alloc::malloc(self.capacity);
212
213 copy_nonoverlapping(self.borrow().as_ref().as_ptr(), ptr.as_ptr(), self.capacity);
214 alloc::mprotect(ptr, alloc::Prot::NoAccess);
215
216 SecretBytes {
217 ptr,
218 size: self.size,
219 capacity: self.capacity,
220 locks: AtomicIsize::new(0),
221 }
222 }
223 }
224}
225
226impl PartialEq for SecretBytes {
227 fn eq(&self, other: &Self) -> bool {
228 self.borrow().as_bytes() == other.borrow().as_bytes()
229 }
230}
231
// Byte-wise equality (see `PartialEq`) is reflexive, symmetric and transitive.
impl Eq for SecretBytes {}
233
234impl From<&mut [u8]> for SecretBytes {
235 fn from(bytes: &mut [u8]) -> Self {
236 unsafe {
237 let ptr = alloc::malloc(bytes.len());
238
239 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
240 memory::memzero(bytes.as_mut_ptr(), bytes.len());
241 alloc::mprotect(ptr, alloc::Prot::NoAccess);
242
243 SecretBytes {
244 ptr,
245 size: bytes.len(),
246 capacity: bytes.len(),
247 locks: AtomicIsize::new(0),
248 }
249 }
250 }
251}
252
253impl From<Vec<u8>> for SecretBytes {
254 fn from(mut bytes: Vec<u8>) -> Self {
255 unsafe {
256 let ptr = alloc::malloc(bytes.len());
257
258 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
259 memory::memzero(bytes.as_mut_ptr(), bytes.len());
260 alloc::mprotect(ptr, alloc::Prot::NoAccess);
261
262 SecretBytes {
263 ptr,
264 size: bytes.len(),
265 capacity: bytes.len(),
266 locks: AtomicIsize::new(0),
267 }
268 }
269 }
270}
271
272impl From<String> for SecretBytes {
273 fn from(mut str: String) -> Self {
274 unsafe {
275 let bytes = str.as_bytes_mut();
276 let ptr = alloc::malloc(bytes.len());
277
278 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr(), bytes.len());
279 memory::memzero(bytes.as_mut_ptr(), bytes.len());
280 alloc::mprotect(ptr, alloc::Prot::NoAccess);
281
282 SecretBytes {
283 ptr,
284 size: bytes.len(),
285 capacity: bytes.len(),
286 locks: AtomicIsize::new(0),
287 }
288 }
289 }
290}
291
292impl Serialize for SecretBytes {
298 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
299 where
300 S: serde::Serializer,
301 {
302 serializer.serialize_bytes(self.borrow().as_bytes())
303 }
304}
305
306impl<'de> Deserialize<'de> for SecretBytes {
307 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
308 where
309 D: serde::Deserializer<'de>,
310 {
311 deserializer.deserialize_bytes(SafeBytesVisitor())
312 }
313}
314
/// serde visitor that turns deserialized byte data into a `SecretBytes`.
struct SafeBytesVisitor();
316
317impl<'de> serde::de::Visitor<'de> for SafeBytesVisitor {
318 type Value = SecretBytes;
319
320 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
321 formatter.write_str("a byte array")
322 }
323
324 fn visit_borrowed_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
325 where
326 E: serde::de::Error,
327 {
328 Ok(SecretBytes::from_secured(v))
329 }
330
331 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
332 where
333 A: serde::de::SeqAccess<'de>,
334 {
335 let mut buf = ZeroizeBytesBuffer::with_capacity(seq.size_hint().unwrap_or(1024));
336
337 while let Some(value) = seq.next_element::<u8>()? {
338 buf.write_u8(value).ok();
339 }
340
341 Ok(SecretBytes::from_secured(&buf))
342 }
343}
344
/// Read guard returned by `SecretBytes::borrow`; its `Drop` impl releases the read lock.
pub struct Ref<'a> {
    bytes: &'a SecretBytes,
}
348
349impl<'a> Ref<'a> {
350 pub fn as_bytes(&self) -> &[u8] {
351 unsafe { slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size) }
352 }
353
354 pub fn as_str(&self) -> &str {
355 unsafe {
356 let bytes = slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size);
357 std::str::from_utf8_unchecked(bytes)
358 }
359 }
360}
361
362impl<'a> Drop for Ref<'a> {
363 fn drop(&mut self) {
364 self.bytes.unlock_read()
365 }
366}
367
368impl<'a> Deref for Ref<'a> {
369 type Target = [u8];
370
371 fn deref(&self) -> &Self::Target {
372 self.as_bytes()
373 }
374}
375
376impl<'a> AsRef<[u8]> for Ref<'a> {
377 fn as_ref(&self) -> &[u8] {
378 self.as_bytes()
379 }
380}
381
/// Write guard returned by `SecretBytes::borrow_mut`; its `Drop` impl releases the write lock.
pub struct RefMut<'a> {
    bytes: &'a mut SecretBytes,
}
385
386impl<'a> RefMut<'a> {
387 pub fn clear(&mut self) {
388 unsafe {
389 memory::memzero(self.bytes.ptr.as_ptr(), self.bytes.capacity);
390 self.bytes.size = 0;
391 }
392 }
393
394 pub fn append_char(&mut self, ch: char) {
395 let ch_len = ch.len_utf8();
396
397 assert!(ch_len + self.bytes.size <= self.bytes.capacity);
398
399 unsafe {
400 let bytes_with_extra = slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size + ch_len);
401 ch.encode_utf8(&mut bytes_with_extra[self.bytes.size..]);
402 }
403 self.bytes.size += ch_len;
404 }
405
406 pub fn remove_char(&mut self) {
407 unsafe {
408 let bytes = slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size);
409 let tail_len = match std::str::from_utf8_unchecked(bytes).chars().last() {
410 Some(ch) => ch.len_utf8(),
411 None => return,
412 };
413 assert!(tail_len <= self.bytes.size);
414 for b in &mut bytes[self.bytes.size - tail_len..] {
415 *b = 0
416 }
417
418 self.bytes.size -= tail_len;
419 }
420 }
421}
422
423impl<'a> Drop for RefMut<'a> {
424 fn drop(&mut self) {
425 self.bytes.unlock_write()
426 }
427}
428
429impl<'a> Deref for RefMut<'a> {
430 type Target = [u8];
431
432 fn deref(&self) -> &Self::Target {
433 unsafe { slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size) }
434 }
435}
436
437impl<'a> DerefMut for RefMut<'a> {
438 fn deref_mut(&mut self) -> &mut Self::Target {
439 unsafe { slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size) }
440 }
441}
442
443impl<'a> AsRef<[u8]> for RefMut<'a> {
444 fn as_ref(&self) -> &[u8] {
445 unsafe { slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.size) }
446 }
447}
448
449impl<'a> AsMut<[u8]> for RefMut<'a> {
450 fn as_mut(&mut self) -> &mut [u8] {
451 unsafe { slice::from_raw_parts_mut(self.bytes.ptr.as_ptr(), self.bytes.size) }
452 }
453}
454
455impl<'a> io::Write for RefMut<'a> {
456 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
457 let available = self.bytes.capacity() - self.bytes.size;
458
459 if available == 0 {
460 return Err(io::ErrorKind::WriteZero.into());
461 }
462 let transfer = available.min(buf.len());
463
464 unsafe {
465 copy_nonoverlapping(buf.as_ptr(), self.bytes.ptr.as_ptr().add(self.bytes.size), transfer);
466 }
467 self.bytes.size += transfer;
468
469 Ok(transfer)
470 }
471
472 fn flush(&mut self) -> io::Result<()> {
473 Ok(())
474 }
475}
476
477impl std::fmt::Debug for SecretBytes {
478 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
479 write!(f, "<Secret>")
480 }
481}
482
#[cfg(test)]
mod tests {
    use rand::{distributions, thread_rng, Rng};
    use spectral::prelude::*;
    use std::iter;

    use super::*;
    use crate::memguard::ZeroizeBytesBuffer;

    // Panics unless both slices have identical content.
    fn assert_slices_equal(actual: &[u8], expected: &[u8]) {
        assert!(actual == expected)
    }

    // `From<&mut [u8]>` wipes the source slice. Read guards can be stacked,
    // and the lock count drops back to 0 once every guard is gone.
    #[test]
    fn test_borrow_read_only() {
        let rng = thread_rng();
        // Non-zero bytes, so the later "all zero" check proves the source was wiped.
        let mut source = rng
            .sample_iter(&distributions::Standard)
            .filter(|b| *b != 0)
            .take(200)
            .collect::<Vec<u8>>();
        let expected = source.clone();

        for b in source.iter() {
            assert_that(b).is_not_equal_to(0);
        }

        let guarded = SecretBytes::from(source.as_mut_slice());

        assert_that(&guarded.len()).is_equal_to(source.len());
        assert_that(&guarded.borrow().as_ref().len()).is_equal_to(source.len());

        for b in source.iter() {
            assert_that(b).is_equal_to(0);
        }

        assert_that(&guarded.locks()).is_equal_to(0);
        assert_slices_equal(&guarded.borrow(), &expected);
        assert_that(&guarded.locks()).is_equal_to(0);

        {
            // Three simultaneous read guards -> lock count 3.
            let ref1 = guarded.borrow();
            let ref2 = guarded.borrow();
            let ref3 = guarded.borrow();

            assert_that(&ref1.len()).is_equal_to(200);
            assert_that(&guarded.locks()).is_equal_to(3);
            assert_slices_equal(&ref1, &expected);
            assert_slices_equal(&ref2, &expected);
            assert_slices_equal(&ref3, &expected);
        }
        assert_that(&guarded.locks()).is_equal_to(0);
    }

    // `zeroed` yields `size` zero bytes with matching length and capacity.
    #[test]
    fn test_zeroed() {
        let guarded = SecretBytes::zeroed(200);

        assert_that(&guarded.len()).is_equal_to(200);
        assert_that(&guarded.capacity()).is_equal_to(200);

        {
            let ref1 = guarded.borrow();

            assert_that(&ref1.len()).is_equal_to(200);
            for b in ref1.as_ref() {
                assert_that(b).is_equal_to(0);
            }
        }
    }

    // Content written through a write guard is visible to later readers, and
    // the lock count is back to 0 afterwards.
    #[test]
    fn test_borrow_read_write() {
        let mut rng = thread_rng();
        let mut source = iter::repeat(())
            .map(|_| rng.sample(distributions::Standard))
            .filter(|b| *b != 0)
            .take(200)
            .collect::<Vec<u8>>();
        let source2 = iter::repeat(())
            .map(|_| rng.sample(distributions::Standard))
            .filter(|b| *b != 0)
            .take(200)
            .collect::<Vec<u8>>();
        let expected = source.clone();
        let expected2 = source2.clone();

        for b in source.iter() {
            assert_that(b).is_not_equal_to(0);
        }

        let mut guarded = SecretBytes::from(source.as_mut_slice());

        for b in source.iter() {
            assert_that(b).is_equal_to(0);
        }

        assert_that(&guarded.locks()).is_equal_to(0);
        assert_slices_equal(&guarded.borrow(), &expected);

        // The temporary write guard is dropped at the end of this statement.
        guarded.borrow_mut().as_mut().copy_from_slice(&source2);

        assert_that(&guarded.locks()).is_equal_to(0);
        assert_slices_equal(&guarded.borrow(), &expected2);
    }

    // `random` produces a buffer of the requested length.
    #[test]
    fn test_init_strong_random() {
        let mut rng = thread_rng();
        let random = SecretBytes::random(&mut rng, 32);

        assert_that(&random.len()).is_equal_to(32);
        assert_that(&random.borrow().as_ref().len()).is_equal_to(32);
    }

    // append_char / remove_char with 1-, 2- and 3-byte UTF-8 chars; removing
    // from an empty buffer is a no-op.
    #[test]
    fn test_str_like_ops() {
        let mut secret = SecretBytes::with_capacity_for_chars(20);

        assert_that(&secret.len()).is_equal_to(0);
        assert_that(&secret.capacity()).is_equal_to(80);

        secret.borrow_mut().append_char('a');
        assert_that(&secret.len()).is_equal_to(1);
        secret.borrow_mut().append_char('ä');
        assert_that(&secret.len()).is_equal_to(3);
        assert_that(&secret.borrow().as_str().chars().count()).is_equal_to(2);
        secret.borrow_mut().append_char('€');
        assert_that(&secret.len()).is_equal_to(6);
        assert_that(&secret.borrow().as_str().chars().count()).is_equal_to(3);
        secret.borrow_mut().append_char('ß');
        assert_that(&secret.len()).is_equal_to(8);
        assert_that(&secret.borrow().as_str().chars().count()).is_equal_to(4);
        assert_that(&secret.borrow().as_str()).is_equal_to("aä€ß");

        secret.borrow_mut().remove_char();
        assert_that(&secret.len()).is_equal_to(6);
        assert_that(&secret.borrow().as_str().chars().count()).is_equal_to(3);
        assert_that(&secret.borrow().as_str()).is_equal_to("aä€");

        secret.borrow_mut().remove_char();
        assert_that(&secret.len()).is_equal_to(3);
        assert_that(&secret.borrow().as_str().chars().count()).is_equal_to(2);
        assert_that(&secret.borrow().as_str()).is_equal_to("aä");

        secret.borrow_mut().remove_char();
        assert_that(&secret.len()).is_equal_to(1);

        secret.borrow_mut().remove_char();
        assert_that(&secret.len()).is_equal_to(0);

        secret.borrow_mut().remove_char();
        assert_that(&secret.len()).is_equal_to(0);
    }

    // JSON round trip. `serialize_bytes` becomes a number array in JSON, so
    // this exercises the visitor's `visit_seq` path.
    #[test]
    fn test_serde_json() {
        let mut rng = thread_rng();
        let random = SecretBytes::random(&mut rng, 32);
        let mut buffer = ZeroizeBytesBuffer::with_capacity(1024);

        serde_json::to_writer(&mut buffer, &random).unwrap();

        let deserialized: SecretBytes = serde_json::from_reader(buffer.as_ref()).unwrap();

        assert_that(&deserialized).is_equal_to(&random);
    }

    // MessagePack round trip through an in-memory slice (`from_read_ref`).
    #[test]
    fn test_serde_rmb() {
        let mut rng = thread_rng();
        let random = SecretBytes::random(&mut rng, 32);
        let mut buffer = ZeroizeBytesBuffer::with_capacity(1024);

        rmp_serde::encode::write_named(&mut buffer, &random).unwrap();

        let deserialized: SecretBytes = rmp_serde::from_read_ref(&buffer).unwrap();

        assert_that(&deserialized).is_equal_to(&random);
    }
}