t_rust_less_lib/memguard/
memory.rs

1use std::ptr;
2
// -- memeq / memcmp --
4
/// Compares two byte regions for equality without stopping at the first mismatch.
///
/// Every byte pair in `0..len` is loaded with a volatile read, the XOR of each
/// pair is OR-ed into a single accumulator, and the regions are reported equal
/// when that accumulator is still zero after the last byte.
///
/// Returns `true` when the first `len` bytes of `b1` and `b2` are identical
/// (which includes `len == 0`).
///
/// # Safety
///
/// `b1` and `b2` have to point to a memory section of at least `len` bytes
#[inline(never)]
pub unsafe fn memeq(b1: *const u8, b2: *const u8, len: usize) -> bool {
  // Collects every bit in which the two regions differ.
  let mut acc: u8 = 0;
  for idx in 0..len as isize {
    let lhs = ptr::read_volatile(b1.offset(idx));
    let rhs = ptr::read_volatile(b2.offset(idx));
    acc |= lhs ^ rhs;
  }
  acc == 0
}
17
/// Three-way comparison of two byte regions using branch-free arithmetic.
///
/// The regions are walked from the last byte down to the first, and every byte
/// pair is loaded with a volatile read. The difference of the lowest-index
/// mismatching pair decides the result.
///
/// Returns `0` when the first `len` bytes are equal, `1` when the first
/// differing byte of `b1` is greater than the one in `b2`, and `-1` otherwise.
///
/// # Safety
///
/// `b1` and `b2` have to point to a memory section of at least `len` bytes
#[inline(never)]
#[allow(dead_code)]
pub unsafe fn memcmp(b1: *const u8, b2: *const u8, len: usize) -> i32 {
  // Byte difference of the lowest-index mismatch seen so far (0 if none).
  let mut acc: i32 = 0;
  let mut idx = len as isize;
  while idx > 0 {
    idx -= 1;
    let lhs = i32::from(ptr::read_volatile(b1.offset(idx)));
    let rhs = i32::from(ptr::read_volatile(b2.offset(idx)));
    let delta = lhs - rhs;
    // All ones when `delta == 0` (keep the previous value), zero otherwise
    // (let `delta` overwrite it). `delta` is always within -255..=255.
    let keep_mask = ((delta - 1) & !delta) >> 8;
    acc = (acc & keep_mask) | delta;
  }
  // Collapse `acc` to its sign without branching:
  //   acc == 0 -> -1 +  0 + 1 =  0
  //   acc  > 0 ->  0 +  0 + 1 =  1
  //   acc  < 0 -> -1 + -1 + 1 = -1
  let not_positive = (acc - 1) >> 8;
  let negative = acc >> 8;
  not_positive + negative + 1
}
33
34// -- memset / memzero --
35
/// General `memset` for builds with the `nightly` feature.
///
/// Sets `n` bytes starting at `s` to `c` by delegating to the
/// `core::intrinsics::volatile_set_memory` intrinsic.
///
/// # Safety
///
/// `s` has to point to a memory section of at least `n` bytes
#[cfg(feature = "nightly")]
#[cfg(any(not(apple), not(feature = "use_os")))]
#[inline(never)]
pub unsafe fn memset(s: *mut u8, c: u8, n: usize) {
  use core::intrinsics::volatile_set_memory;

  volatile_set_memory(s, c, n);
}
47
/// General `memset` for builds without the `nightly` feature.
///
/// Writes `c` to each of the `n` bytes starting at `s`, one volatile store per
/// byte, in ascending address order.
///
/// # Safety
///
/// `s` has to point to a memory section of at least `n` bytes
#[cfg(not(feature = "nightly"))]
#[cfg(any(not(apple), not(feature = "use_os")))]
#[inline(never)]
pub unsafe fn memset(s: *mut u8, c: u8, n: usize) {
  (0..n).for_each(|offset| ptr::write_volatile(s.add(offset), c));
}
61
/// Call `memset_s`.
///
/// Sets `n` bytes starting at `s` to `c` through the C library's `memset_s`.
/// A zero `n` returns immediately without calling `memset_s`; a non-zero
/// status from `memset_s` aborts the process.
///
/// # Safety
///
/// `s` has to point to a memory section of at least `n` bytes
#[cfg(all(apple, feature = "use_os"))]
pub unsafe fn memset(s: *mut u8, c: u8, n: usize) {
  use libc::{c_int, c_void};
  use mach_o_sys::ranlib::{errno_t, rsize_t};

  extern "C" {
    fn memset_s(s: *mut c_void, smax: rsize_t, c: c_int, n: rsize_t) -> errno_t;
  }

  if n == 0 {
    return;
  }

  // The destination size (`smax`) and the fill count are both `n`.
  let status = memset_s(s as *mut c_void, n as _, c as _, n as _);
  if status != 0 {
    std::process::abort()
  }
}
76
77/// General `memzero`.
78///
79/// # Safety
80///
81/// `dest` has to point to a memory section of at least `n` bytes
82#[cfg(any(
83  not(any(all(windows, not(target_env = "msvc")), freebsdlike, netbsdlike)),
84  not(feature = "use_os")
85))]
86#[inline]
87pub unsafe fn memzero(dest: *mut u8, n: usize) {
88  memset(dest, 0, n);
89}
90
/// Call `explicit_bzero`.
///
/// Zeroes `n` bytes starting at `dest` through the C library's
/// `explicit_bzero`.
///
/// # Safety
///
/// `dest` has to point to a memory section of at least `n` bytes
#[cfg(all(any(freebsdlike, netbsdlike), feature = "use_os"))]
pub unsafe fn memzero(dest: *mut u8, n: usize) {
  use libc::{c_void, size_t};

  extern "C" {
    fn explicit_bzero(s: *mut c_void, n: size_t);
  }

  explicit_bzero(dest as *mut c_void, n);
}
103
/// Call `RtlSecureZeroMemory`.
///
/// Zeroes `n` bytes starting at `s` through the system's
/// `RtlSecureZeroMemory` routine.
///
/// # Safety
///
/// `s` has to point to a memory section of at least `n` bytes
#[cfg(all(windows, not(target_env = "msvc"), feature = "use_os"))]
pub unsafe fn memzero(s: *mut u8, n: usize) {
  use winapi::shared::basetsd::SIZE_T;
  use winapi::shared::ntdef::PVOID;

  extern "system" {
    fn RtlSecureZeroMemory(ptr: PVOID, cnt: SIZE_T);
  }

  RtlSecureZeroMemory(s as PVOID, n as SIZE_T);
}
116
117/// Unix `mlock`.
118///
119/// # Safety
120///
121/// `addr` has to point to a memory section of at least `len` bytes
122#[cfg(unix)]
123pub unsafe fn mlock(addr: *mut u8, len: usize) -> bool {
124  #[cfg(target_os = "linux")]
125  libc::madvise(addr as *mut ::libc::c_void, len, ::libc::MADV_DONTDUMP);
126
127  #[cfg(freebsdlike)]
128  libc::madvise(addr as *mut ::libc::c_void, len, ::libc::MADV_NOCORE);
129
130  libc::mlock(addr as *mut ::libc::c_void, len) == 0
131}
132
/// Windows `VirtualLock`.
///
/// Locks the region with `VirtualLock` and returns `true` when the call
/// returns a non-zero value.
///
/// # Safety
///
/// `addr` has to point to a memory section of at least `len` bytes
#[cfg(windows)]
pub unsafe fn mlock(addr: *mut u8, len: usize) -> bool {
  use winapi::shared::basetsd::SIZE_T;
  use winapi::shared::minwindef::LPVOID;
  use winapi::um::memoryapi::VirtualLock;

  let status = VirtualLock(addr as LPVOID, len as SIZE_T);
  status != 0
}
145
146/// Unix `munlock`.
147///
148/// # Safety
149///
150/// `addr` has to point to a memory section of at least `len` bytes
151#[cfg(unix)]
152pub unsafe fn munlock(addr: *mut u8, len: usize) -> bool {
153  memzero(addr, len);
154
155  #[cfg(target_os = "linux")]
156  libc::madvise(addr as *mut ::libc::c_void, len, ::libc::MADV_DODUMP);
157
158  #[cfg(freebsdlike)]
159  libc::madvise(addr as *mut ::libc::c_void, len, ::libc::MADV_CORE);
160
161  libc::munlock(addr as *mut ::libc::c_void, len) == 0
162}
163
/// Windows `VirtualUnlock`.
///
/// The region is zeroed with `memzero` first and then unlocked with
/// `VirtualUnlock`. Returns `true` when `VirtualUnlock` returns a non-zero
/// value.
///
/// # Safety
///
/// `addr` has to point to a memory section of at least `len` bytes
#[cfg(windows)]
pub unsafe fn munlock(addr: *mut u8, len: usize) -> bool {
  use winapi::shared::basetsd::SIZE_T;
  use winapi::shared::minwindef::LPVOID;
  use winapi::um::memoryapi::VirtualUnlock;

  // Wipe the contents while the region is still locked.
  memzero(addr, len);

  let status = VirtualUnlock(addr as LPVOID, len as SIZE_T);
  status != 0
}
173
#[cfg(test)]
mod tests {
  use std::cmp;
  use std::mem;

  use quickcheck::quickcheck;

  use super::*;

  /// `memzero` clears exactly the requested byte range and nothing outside it.
  #[test]
  fn memzero_test() {
    let mut words: [usize; 16] = [1; 16];
    let word_size = mem::size_of_val(&words[0]);

    // Zero the whole array.
    unsafe {
      memzero(words.as_mut_ptr() as *mut u8, mem::size_of_val(&words));
    }
    assert_eq!(words, [0; 16]);

    // Restore the all-ones pattern.
    words.clone_from_slice(&[1; 16]);
    assert_eq!(words, [1; 16]);

    // Zero only elements 1 through 10; both ends must stay untouched.
    unsafe {
      memzero(words[1..11].as_mut_ptr() as *mut u8, 10 * word_size);
    }
    assert_eq!(words, [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]);
  }

  /// `memeq` agrees with `libc::memcmp(..) == 0` on the common prefix of two random vectors.
  #[test]
  #[cfg(unix)]
  fn memeq_test() {
    // quickcheck hands the property owned vectors, so by-value parameters are required.
    #[allow(clippy::needless_pass_by_value)]
    fn check_memeq(x: Vec<u8>, y: Vec<u8>) -> bool {
      let len = cmp::min(x.len(), y.len());
      let memsec_output = unsafe { memeq(x.as_ptr(), y.as_ptr(), len) };
      let libc_ordering =
        unsafe { libc::memcmp(x.as_ptr() as *const libc::c_void, y.as_ptr() as *const libc::c_void, len) };
      memsec_output == (libc_ordering == 0)
    }
    quickcheck(check_memeq as fn(Vec<u8>, Vec<u8>) -> bool);
  }

  /// `memcmp` has the same sign as `libc::memcmp` on the common prefix of two random vectors.
  #[test]
  #[cfg(unix)]
  fn memcmp_test() {
    // quickcheck hands the property owned vectors, so by-value parameters are required.
    #[allow(clippy::needless_pass_by_value)]
    fn check_memcmp(x: Vec<u8>, y: Vec<u8>) -> bool {
      let len = cmp::min(x.len(), y.len());
      let memsec_output = unsafe { memcmp(x.as_ptr(), y.as_ptr(), len) };
      let libc_output =
        unsafe { libc::memcmp(x.as_ptr() as *const libc::c_void, y.as_ptr() as *const libc::c_void, len) };
      // Only the sign is comparable: libc may return any negative or positive magnitude.
      memsec_output.signum() == libc_output.signum()
    }
    quickcheck(check_memcmp as fn(Vec<u8>, Vec<u8>) -> bool);
  }

  /// Locking and then unlocking a buffer both succeed, and `munlock` leaves it zeroed.
  #[test]
  fn mlock_munlock_test() {
    let mut buf = [1u8; 16];
    let size = mem::size_of_val(&buf);

    unsafe {
      assert!(mlock(buf.as_mut_ptr(), size));
      assert!(munlock(buf.as_mut_ptr(), size));
    }
    assert_eq!(buf, [0; 16]);
  }
}