t_rust_less_lib/memguard/
words.rs1use super::alloc;
2use super::memory;
3use capnp::message::{AllocationStrategy, Allocator, SUGGESTED_ALLOCATION_STRATEGY, SUGGESTED_FIRST_SEGMENT_WORDS};
4use capnp::Word;
5use log::warn;
6use std::convert::{AsMut, AsRef};
7use std::ops::{Deref, DerefMut};
8use std::ptr::{copy_nonoverlapping, NonNull};
9use std::slice;
10use std::sync::atomic::{AtomicIsize, Ordering};
11
12pub struct SecretWords {
29 ptr: NonNull<Word>,
30 size: usize,
31 capacity: usize,
32 locks: AtomicIsize,
33}
34
35impl SecretWords {
36 pub fn from_secured(bytes: &[u8]) -> Self {
42 if bytes.len() % 8 != 0 {
43 warn!("Bytes not aligned to 8 bytes. Probably these are not the bytes you are looking for.");
44 }
45 unsafe {
46 let len = bytes.len() / 8;
47 let ptr = alloc::malloc(len * 8).cast();
48
49 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr() as *mut u8, len * 8);
50 alloc::mprotect(ptr, alloc::Prot::NoAccess);
51
52 SecretWords {
53 ptr,
54 size: len,
55 capacity: len,
56 locks: AtomicIsize::new(0),
57 }
58 }
59 }
60
61 pub fn with_capacity(capacity: usize) -> SecretWords {
62 unsafe {
63 let ptr = alloc::malloc(capacity * 8).cast();
64
65 alloc::mprotect(ptr, alloc::Prot::NoAccess);
66
67 SecretWords {
68 ptr,
69 size: 0,
70 capacity,
71 locks: AtomicIsize::new(0),
72 }
73 }
74 }
75
76 pub fn zeroed(size: usize) -> SecretWords {
77 unsafe {
78 let ptr = alloc::malloc(size * 8).cast();
79
80 memory::memzero(ptr.as_ptr() as *mut u8, size * 8);
81 alloc::mprotect(ptr, alloc::Prot::NoAccess);
82
83 SecretWords {
84 ptr,
85 size,
86 capacity: size,
87 locks: AtomicIsize::new(0),
88 }
89 }
90 }
91
92 pub fn is_empty(&self) -> bool {
93 self.size == 0
94 }
95
96 pub fn len(&self) -> usize {
97 self.size
98 }
99
100 pub fn capacity(&self) -> usize {
101 self.capacity
102 }
103
104 pub fn borrow(&self) -> Ref<'_> {
105 self.lock_read();
106 Ref { words: self }
107 }
108
109 pub fn borrow_mut(&mut self) -> RefMut<'_> {
110 self.lock_write();
111 RefMut { words: self }
112 }
113
114 pub fn locks(&self) -> isize {
115 self.locks.load(Ordering::Relaxed)
116 }
117
118 fn lock_read(&self) {
119 let locks = self.locks.fetch_add(1, Ordering::Relaxed);
120
121 assert!(locks >= 0);
122
123 if locks == 0 {
124 unsafe {
125 alloc::mprotect(self.ptr, alloc::Prot::ReadOnly);
126 }
127 }
128 }
129
130 fn unlock_read(&self) {
131 let locks = self.locks.fetch_sub(1, Ordering::Relaxed);
132
133 assert!(locks > 0);
134
135 if locks == 1 {
136 unsafe {
137 alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
138 }
139 }
140 }
141
142 fn lock_write(&mut self) {
143 let locks = self.locks.fetch_sub(1, Ordering::Relaxed);
144
145 assert!(locks == 0);
146
147 unsafe {
148 alloc::mprotect(self.ptr, alloc::Prot::ReadWrite);
149 }
150 }
151
152 fn unlock_write(&mut self) {
153 let locks = self.locks.fetch_add(1, Ordering::Relaxed);
154
155 assert!(locks == -1);
156
157 unsafe {
158 alloc::mprotect(self.ptr, alloc::Prot::NoAccess);
159 }
160 }
161
162 fn as_mut_ptr(&mut self) -> *mut Word {
165 self.lock_write();
166
167 self.ptr.as_ptr()
168 }
169}
170
// SAFETY: `SecretWords` exclusively owns the allocation behind `ptr` (obtained
// from `alloc::malloc` in its constructors and released only in `Drop`), so
// moving a value to another thread moves sole ownership of that allocation.
unsafe impl Send for SecretWords {}

// SAFETY: through `&SecretWords` the only access to the words is `borrow`,
// which is tracked by the atomic `locks` counter; writing requires `&mut self`.
// NOTE(review): concurrent `borrow` calls additionally rely on
// `lock_read`/`unlock_read` ordering their `mprotect` calls correctly across
// threads — confirm before relying on cross-thread reads.
unsafe impl Sync for SecretWords {}
174
175impl Drop for SecretWords {
176 fn drop(&mut self) {
177 unsafe { alloc::free(self.ptr) }
178 }
179}
180
181impl Clone for SecretWords {
182 fn clone(&self) -> Self {
183 unsafe {
184 let ptr = alloc::malloc(self.capacity * 8).cast::<Word>();
185
186 copy_nonoverlapping(self.borrow().as_words().as_ptr(), ptr.as_ptr(), self.capacity);
187 alloc::mprotect(ptr, alloc::Prot::NoAccess);
188
189 SecretWords {
190 ptr,
191 size: self.size,
192 capacity: self.capacity,
193 locks: AtomicIsize::new(0),
194 }
195 }
196 }
197}
198
199impl From<&mut [u8]> for SecretWords {
200 fn from(bytes: &mut [u8]) -> Self {
201 if bytes.len() % 8 != 0 {
202 warn!("Bytes not aligned to 8 bytes. Probably these are not the bytes you are looking for.");
203 }
204 unsafe {
205 let len = bytes.len() / 8;
206 let ptr = alloc::malloc(len * 8).cast();
207
208 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr() as *mut u8, len * 8);
209 memory::memzero(bytes.as_mut_ptr(), bytes.len());
210 alloc::mprotect(ptr, alloc::Prot::NoAccess);
211
212 SecretWords {
213 ptr,
214 size: len,
215 capacity: len,
216 locks: AtomicIsize::new(0),
217 }
218 }
219 }
220}
221
222impl From<&mut [Word]> for SecretWords {
223 fn from(words: &mut [Word]) -> Self {
224 unsafe {
225 let ptr = alloc::malloc(words.len() * 8).cast();
226
227 copy_nonoverlapping(words.as_ptr(), ptr.as_ptr(), words.len());
228 memory::memzero(words.as_mut_ptr() as *mut u8, words.len() * 8);
229 alloc::mprotect(ptr, alloc::Prot::NoAccess);
230
231 SecretWords {
232 ptr,
233 size: words.len(),
234 capacity: words.len(),
235 locks: AtomicIsize::new(0),
236 }
237 }
238 }
239}
240
241impl From<Vec<u8>> for SecretWords {
242 fn from(mut bytes: Vec<u8>) -> Self {
243 if bytes.len() % 8 != 0 {
244 warn!("Bytes not aligned to 8 bytes. Probably these are not the bytes you are looking for.");
245 }
246 unsafe {
247 let len = bytes.len() / 8;
248 let ptr = alloc::malloc(len * 8).cast();
249
250 copy_nonoverlapping(bytes.as_ptr(), ptr.as_ptr() as *mut u8, len * 8);
251 memory::memzero(bytes.as_mut_ptr(), bytes.len());
252 alloc::mprotect(ptr, alloc::Prot::NoAccess);
253
254 SecretWords {
255 ptr,
256 size: len,
257 capacity: len,
258 locks: AtomicIsize::new(0),
259 }
260 }
261 }
262}
263
264pub struct Ref<'a> {
265 words: &'a SecretWords,
266}
267
268impl Ref<'_> {
269 pub fn as_bytes(&self) -> &[u8] {
270 unsafe {
271 let words = slice::from_raw_parts(self.words.ptr.as_ptr(), self.words.size);
272 Word::words_to_bytes(words)
273 }
274 }
275
276 pub fn as_words(&self) -> &[Word] {
277 unsafe { slice::from_raw_parts(self.words.ptr.as_ptr(), self.words.size) }
278 }
279}
280
281impl Drop for Ref<'_> {
282 fn drop(&mut self) {
283 self.words.unlock_read()
284 }
285}
286
287impl Deref for Ref<'_> {
288 type Target = [u8];
289
290 fn deref(&self) -> &Self::Target {
291 unsafe { slice::from_raw_parts(self.words.ptr.as_ptr() as *const u8, self.words.size * 8) }
292 }
293}
294
295impl AsRef<[u8]> for Ref<'_> {
296 fn as_ref(&self) -> &[u8] {
297 self.as_bytes()
298 }
299}
300
301pub struct RefMut<'a> {
302 words: &'a mut SecretWords,
303}
304
305impl Drop for RefMut<'_> {
306 fn drop(&mut self) {
307 self.words.unlock_write()
308 }
309}
310
311impl Deref for RefMut<'_> {
312 type Target = [u8];
313
314 fn deref(&self) -> &Self::Target {
315 unsafe { slice::from_raw_parts(self.words.ptr.as_ptr() as *const u8, self.words.size * 8) }
316 }
317}
318
319impl DerefMut for RefMut<'_> {
320 fn deref_mut(&mut self) -> &mut Self::Target {
321 unsafe { slice::from_raw_parts_mut(self.words.ptr.as_ptr() as *mut u8, self.words.size * 8) }
322 }
323}
324
325impl AsRef<[u8]> for RefMut<'_> {
326 fn as_ref(&self) -> &[u8] {
327 unsafe { slice::from_raw_parts(self.words.ptr.as_ptr() as *const u8, self.words.size * 8) }
328 }
329}
330
331impl AsMut<[u8]> for RefMut<'_> {
332 fn as_mut(&mut self) -> &mut [u8] {
333 unsafe { slice::from_raw_parts_mut(self.words.ptr.as_ptr() as *mut u8, self.words.size * 8) }
334 }
335}
336
337pub struct SecureHHeapAllocator {
338 owned_memory: Vec<SecretWords>,
339 next_size: u32,
340 allocation_strategy: AllocationStrategy,
341}
342
343unsafe impl Allocator for SecureHHeapAllocator {
344 fn allocate_segment(&mut self, minimum_size: u32) -> (*mut u8, u32) {
345 let size = ::std::cmp::max(minimum_size, self.next_size);
346 let mut new_words = SecretWords::zeroed(size as usize);
347 let ptr = new_words.as_mut_ptr() as *mut u8;
348 self.owned_memory.push(new_words);
349
350 if let AllocationStrategy::GrowHeuristically = self.allocation_strategy {
351 self.next_size += size;
352 }
353 (ptr, size)
354 }
355
356 unsafe fn deallocate_segment(&mut self, _ptr: *mut u8, _word_size: u32, _words_used: u32) {
357 self.next_size = SUGGESTED_FIRST_SEGMENT_WORDS;
358 }
359}
360
361impl Default for SecureHHeapAllocator {
362 fn default() -> Self {
363 SecureHHeapAllocator {
364 owned_memory: Vec::new(),
365 next_size: SUGGESTED_FIRST_SEGMENT_WORDS,
366 allocation_strategy: SUGGESTED_ALLOCATION_STRATEGY,
367 }
368 }
369}
370
// Unit tests for `SecretWords`: conversion from caller buffers (which must be
// wiped), borrow/lock counting, zero initialisation and unaligned sources.
#[cfg(test)]
mod tests {
    use byteorder::{BigEndian, ByteOrder};
    use rand::{distributions, thread_rng, Rng};
    use spectral::prelude::*;
    use std::iter;

    use super::*;

    // Compares with `assert!` rather than `assert_eq!`, so the slice contents
    // are not printed when the comparison fails.
    fn assert_slices_equal(actual: &[u8], expected: &[u8]) {
        assert!(actual == expected)
    }

    // Builds a capnp `Word` whose bytes are `w` in little-endian order
    // (least significant byte first).
    fn word_from_u64(w: u64) -> Word {
        capnp::word(
            (w & 0xff) as u8,
            ((w >> 8) & 0xff) as u8,
            ((w >> 16) & 0xff) as u8,
            ((w >> 24) & 0xff) as u8,
            ((w >> 32) & 0xff) as u8,
            ((w >> 40) & 0xff) as u8,
            ((w >> 48) & 0xff) as u8,
            ((w >> 56) & 0xff) as u8,
        )
    }

    // `From<&mut [u8]>` must wipe the source; read borrows must be counted
    // and released again.
    #[test]
    fn test_borrow_read_only() {
        // Random non-zero bytes, so a wiped source is distinguishable.
        let rng = thread_rng();
        let mut source = rng
            .sample_iter::<u8, _>(&distributions::Standard)
            .filter(|w| *w != 0)
            .take(200 * 8)
            .collect::<Vec<u8>>();
        let expected = source.clone();

        for w in source.iter() {
            assert_that(&w).is_not_equal_to(&0);
        }

        let guarded = SecretWords::from(source.as_mut_slice());

        assert_that(&guarded.len()).is_equal_to(source.len() / 8);
        assert_that(&guarded.borrow().as_words().len()).is_equal_to(source.len() / 8);

        // The conversion must have zeroed the caller's buffer.
        for w in source.iter() {
            assert_that(&w).is_equal_to(&0);
        }

        // A temporary guard is released by the end of its statement.
        assert_that(&guarded.locks()).is_equal_to(0);
        assert_slices_equal(&guarded.borrow(), &expected);
        assert_that(&guarded.locks()).is_equal_to(0);

        {
            // Simultaneous read guards are each counted.
            let ref1 = guarded.borrow();
            let ref2 = guarded.borrow();
            let ref3 = guarded.borrow();

            assert_that(&ref1.len()).is_equal_to(200 * 8);
            assert_that(&guarded.locks()).is_equal_to(3);
            assert_slices_equal(&ref1, &expected);
            assert_slices_equal(&ref2, &expected);
            assert_slices_equal(&ref3, &expected);
        }
        assert_that(&guarded.locks()).is_equal_to(0);
    }

    // `zeroed` reports the requested length/capacity and every word is 0.
    #[test]
    fn test_zeroed() {
        let guarded = SecretWords::zeroed(200);

        assert_that(&guarded.len()).is_equal_to(200);
        assert_that(&guarded.capacity()).is_equal_to(200);

        {
            let ref1 = guarded.borrow();

            assert_that(&ref1.len()).is_equal_to(200 * 8);
            for w in ref1.as_words() {
                assert_that(&w).is_equal_to(&word_from_u64(0));
            }
        }
    }

    // Data written through `borrow_mut` is visible to later read borrows and
    // the write lock is released afterwards.
    #[test]
    fn test_borrow_read_write() {
        let mut rng = thread_rng();
        let mut source = iter::repeat(())
            .map(|_| rng.sample(distributions::Standard))
            .filter(|w| *w != 0)
            .take(200 * 8)
            .collect::<Vec<u8>>();
        let source2 = iter::repeat(())
            .map(|_| rng.sample(distributions::Standard))
            .filter(|w| *w != 0)
            .take(200 * 8)
            .collect::<Vec<u8>>();
        let expected = source.clone();
        let expected2 = source2.clone();

        for w in source.iter() {
            assert_that(&w).is_not_equal_to(&0);
        }

        let mut guarded = SecretWords::from(source.as_mut_slice());

        for w in source.iter() {
            assert_that(&w).is_equal_to(&0);
        }

        assert_that(&guarded.locks()).is_equal_to(0);
        assert_slices_equal(&guarded.borrow(), &expected);

        // Overwrite the whole buffer through a temporary write guard.
        guarded.borrow_mut().as_mut().copy_from_slice(&source2);

        assert_that(&guarded.locks()).is_equal_to(0);
        assert_slices_equal(&guarded.borrow(), &expected2);
    }

    // Sources starting at a non-multiple-of-8 offset (1 and 3) are copied
    // correctly and wiped. Each source slice is 1600 bytes = 200 words.
    #[test]
    fn test_from_unaligned_source() {
        let mut chunks = [0u8; 16];

        BigEndian::write_u64(&mut chunks[0..8], 0x1234_5678_1234_5678);
        BigEndian::write_u64(&mut chunks[8..16], 0xf0e1_d2c3_b4a5_9687);

        let mut bytes1 = [0u8; 100 * 16 + 1];
        let mut bytes2 = [0u8; 100 * 16 + 3];

        for i in 0..100 {
            bytes1[i * 16 + 1..i * 16 + 1 + 16].copy_from_slice(&chunks);
            bytes2[i * 16 + 3..i * 16 + 3 + 16].copy_from_slice(&chunks);
        }

        let guarded1 = SecretWords::from(&mut bytes1[1..]);
        let guarded2 = SecretWords::from(&mut bytes2[3..]);

        for b in &bytes1[..] {
            assert_that(b).is_equal_to(0);
        }
        for b in &bytes2[..] {
            assert_that(b).is_equal_to(0);
        }

        assert_that(&guarded1.len()).is_equal_to(200);
        assert_that(&guarded2.len()).is_equal_to(200);

        // Words alternate between the two big-endian patterns written above.
        for (idx, w) in guarded1.borrow().chunks(8).enumerate() {
            if idx % 2 == 0 {
                assert_that(&w).is_equal_to(&[0x12u8, 0x34u8, 0x56u8, 0x78u8, 0x12u8, 0x34u8, 0x56u8, 0x78u8][..]);
            } else {
                assert_that(&w).is_equal_to(&[0xf0u8, 0xe1u8, 0xd2u8, 0xc3u8, 0xb4u8, 0xa5u8, 0x96u8, 0x87u8][..]);
            }
        }
        for (idx, w) in guarded2.borrow().chunks(8).enumerate() {
            if idx % 2 == 0 {
                assert_that(&w).is_equal_to(&[0x12u8, 0x34u8, 0x56u8, 0x78u8, 0x12u8, 0x34u8, 0x56u8, 0x78u8][..]);
            } else {
                assert_that(&w).is_equal_to(&[0xf0u8, 0xe1u8, 0xd2u8, 0xc3u8, 0xb4u8, 0xa5u8, 0x96u8, 0x87u8][..]);
            }
        }
    }
}