1#![allow(unsafe_code)] use crate::{P2PError, Result};
35use std::alloc::{Layout, alloc_zeroed, dealloc};
36use std::collections::VecDeque;
37use std::fmt;
38use std::ops::Deref;
39use std::ptr::{self, NonNull};
40use std::sync::Mutex;
41
42#[cfg(unix)]
43use libc::{mlock, munlock};
44
45#[cfg(windows)]
46use winapi::um::memoryapi::{VirtualLock, VirtualUnlock};
47
/// Upper bound, in bytes, on a single secure allocation (64 KiB).
const MAX_SECURE_ALLOCATION: usize = 64 * 1024;

/// Number of bytes the default pool pre-allocates (1 MiB).
const DEFAULT_POOL_SIZE: usize = 1024 * 1024;

/// Alignment, in bytes, of every secure allocation.
///
/// Sizes are rounded up to a multiple of this value, so it must be a power of two.
const SECURE_ALIGNMENT: usize = 64;
/// A fixed-size heap buffer intended to hold secret bytes.
///
/// `SecureMemory::new` allocates it zeroed and tries to lock its pages in RAM.
/// The `Drop` impl wipes it with volatile writes before freeing it.
pub struct SecureMemory {
    /// Start of the allocation (never null).
    ptr: NonNull<u8>,
    /// Bytes actually allocated: the requested size rounded up to `SECURE_ALIGNMENT`.
    size: usize,
    /// Bytes requested by the caller; this is the length exposed by `as_slice`.
    data_len: usize,
    /// Whether the pages are currently locked (`mlock` / `VirtualLock`).
    locked: bool,
    /// Layout the buffer was allocated with; needed again to free it.
    layout: Layout,
}

// SAFETY: the allocation has exactly one owner (this value) and is freed only in
// `Drop`, so transferring ownership to another thread is sound.
unsafe impl Send for SecureMemory {}
// SAFETY: all mutation goes through `&mut self`. A `&SecureMemory` only yields
// read-only slices, so shared access from several threads cannot race.
unsafe impl Sync for SecureMemory {}
75
76pub struct SecureVec {
78 memory: SecureMemory,
80 len: usize,
82}
83
84pub struct SecureString {
86 vec: SecureVec,
88}
89
90pub struct SecureMemoryPool {
92 available: Mutex<VecDeque<SecureMemory>>,
94 total_size: usize,
96 chunk_size: usize,
98 stats: Mutex<PoolStats>,
100}
101
/// Snapshot of the counters maintained by a `SecureMemoryPool`.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Allocations recorded by the pool.
    pub total_allocations: u64,
    /// Deallocations recorded by the pool.
    pub total_deallocations: u64,
    /// Allocations recorded but not yet deallocated.
    pub active_allocations: u64,
    /// Requests served from an already-available chunk.
    pub pool_hits: u64,
    /// Requests that needed a fresh allocation.
    pub pool_misses: u64,
    /// Bytes newly allocated to serve misses.
    pub total_bytes_allocated: u64,
    /// Bytes currently handed out.
    pub current_bytes_in_use: u64,
}
120
/// Error conditions specific to secure-memory operations.
#[derive(Debug, Clone)]
pub enum SecureMemoryError {
    /// Memory could not be allocated; the payload describes why.
    AllocationFailed(String),
    /// Pages could not be locked in memory; the payload describes why.
    LockingFailed(String),
    /// Arguments were rejected; the payload describes which.
    InvalidParameters(String),
    /// The secure memory pool had nothing left to hand out.
    PoolExhausted,
    /// The requested operation is not supported; the payload names it.
    NotSupported(String),
}

impl std::fmt::Display for SecureMemoryError {
    /// Renders `"<category>: <detail>"`, or just the category for `PoolExhausted`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::AllocationFailed(msg) => ("Allocation failed", Some(msg)),
            Self::LockingFailed(msg) => ("Memory locking failed", Some(msg)),
            Self::InvalidParameters(msg) => ("Invalid parameters", Some(msg)),
            Self::PoolExhausted => ("Secure memory pool exhausted", None),
            Self::NotSupported(msg) => ("Operation not supported", Some(msg)),
        };
        match detail {
            Some(msg) => write!(f, "{prefix}: {msg}"),
            None => f.write_str(prefix),
        }
    }
}

impl std::error::Error for SecureMemoryError {}
149
150impl SecureMemory {
151 pub fn new(size: usize) -> Result<Self> {
153 if size == 0 {
154 return Err(P2PError::Io(std::io::Error::new(
155 std::io::ErrorKind::InvalidInput,
156 "Cannot allocate zero-sized memory",
157 )));
158 }
159
160 if size > MAX_SECURE_ALLOCATION {
161 return Err(P2PError::Io(std::io::Error::new(
162 std::io::ErrorKind::InvalidInput,
163 format!("Allocation size {size} exceeds maximum {MAX_SECURE_ALLOCATION}"),
164 )));
165 }
166
167 let aligned_size = (size + SECURE_ALIGNMENT - 1) & !(SECURE_ALIGNMENT - 1);
169
170 let layout = Layout::from_size_align(aligned_size, SECURE_ALIGNMENT).map_err(|e| {
172 P2PError::Io(std::io::Error::new(
173 std::io::ErrorKind::InvalidInput,
174 format!("Invalid layout: {e}"),
175 ))
176 })?;
177
178 let ptr = unsafe { alloc_zeroed(layout) };
180 if ptr.is_null() {
181 return Err(P2PError::Io(std::io::Error::new(
182 std::io::ErrorKind::OutOfMemory,
183 "Memory allocation failed",
184 )));
185 }
186
187 let ptr = NonNull::new(ptr).ok_or_else(|| {
188 P2PError::Io(std::io::Error::new(
189 std::io::ErrorKind::OutOfMemory,
190 "Null pointer returned from allocator",
191 ))
192 })?;
193
194 let mut memory = Self {
195 ptr,
196 size: aligned_size,
197 data_len: size,
198 locked: false,
199 layout,
200 };
201
202 if let Err(e) = memory.lock_memory() {
204 tracing::warn!("Failed to lock secure memory: {}", e);
205 }
206
207 Ok(memory)
208 }
209
210 pub fn from_slice(data: &[u8]) -> Result<Self> {
212 let mut memory = Self::new(data.len())?;
213 memory.as_mut_slice()[..data.len()].copy_from_slice(data);
214 Ok(memory)
215 }
216
217 pub fn len(&self) -> usize {
219 self.size
220 }
221
222 pub fn is_empty(&self) -> bool {
224 self.size == 0
225 }
226
227 pub fn as_slice(&self) -> &[u8] {
229 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.data_len) }
230 }
231
232 pub fn as_mut_slice(&mut self) -> &mut [u8] {
234 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.data_len) }
235 }
236
237 pub fn as_allocated_slice(&self) -> &[u8] {
239 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
240 }
241
242 pub fn as_allocated_mut_slice(&mut self) -> &mut [u8] {
244 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
245 }
246
247 pub fn constant_time_eq(&self, other: &SecureMemory) -> bool {
249 if self.data_len != other.data_len {
250 return false;
251 }
252
253 let a = self.as_slice();
254 let b = other.as_slice();
255
256 let mut result = 0u8;
258 for i in 0..self.data_len {
259 result |= a[i] ^ b[i];
260 }
261
262 result == 0
263 }
264
265 fn lock_memory(&mut self) -> Result<()> {
267 if self.locked {
268 return Ok(());
269 }
270
271 #[cfg(unix)]
272 {
273 let result = unsafe { mlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
274 if result != 0 {
275 return Err(P2PError::Io(std::io::Error::new(
276 std::io::ErrorKind::PermissionDenied,
277 "Failed to lock memory pages",
278 )));
279 }
280 }
281
282 #[cfg(windows)]
283 {
284 let result =
285 unsafe { VirtualLock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
286 if result == 0 {
287 return Err(P2PError::Io(std::io::Error::new(
288 std::io::ErrorKind::PermissionDenied,
289 "VirtualLock failed",
290 )));
291 }
292 }
293
294 #[cfg(not(any(unix, windows)))]
295 {
296 tracing::warn!("Memory locking not supported on this platform");
297 }
298
299 self.locked = true;
300 Ok(())
301 }
302
303 fn unlock_memory(&mut self) {
305 if !self.locked {
306 return;
307 }
308
309 #[cfg(unix)]
310 {
311 unsafe { munlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
312 }
313
314 #[cfg(windows)]
315 {
316 unsafe { VirtualUnlock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
317 }
318
319 self.locked = false;
320 }
321
322 pub fn zeroize(&mut self) {
324 unsafe {
325 ptr::write_volatile(self.ptr.as_ptr(), 0u8);
327
328 for i in 0..self.size {
330 ptr::write_volatile(self.ptr.as_ptr().add(i), 0u8);
331 }
332 }
333 }
334}
335
336impl Drop for SecureMemory {
337 fn drop(&mut self) {
338 self.zeroize();
340
341 self.unlock_memory();
343
344 unsafe {
346 dealloc(self.ptr.as_ptr(), self.layout);
347 }
348 }
349}
350
351impl SecureVec {
352 pub fn with_capacity(capacity: usize) -> Result<Self> {
354 let memory = SecureMemory::new(capacity)?;
355 Ok(Self { memory, len: 0 })
356 }
357
358 pub fn from_slice(data: &[u8]) -> Result<Self> {
360 let memory = SecureMemory::from_slice(data)?;
361 let len = data.len();
362 Ok(Self { memory, len })
363 }
364
365 pub fn len(&self) -> usize {
367 self.len
368 }
369
370 pub fn is_empty(&self) -> bool {
372 self.len == 0
373 }
374
375 pub fn capacity(&self) -> usize {
377 self.memory.len()
378 }
379
380 pub fn push(&mut self, value: u8) -> Result<()> {
382 if self.len >= self.capacity() {
383 return Err(P2PError::Io(std::io::Error::new(
384 std::io::ErrorKind::InvalidInput,
385 "SecureVec capacity exceeded",
386 )));
387 }
388
389 self.memory.as_allocated_mut_slice()[self.len] = value;
390 self.len += 1;
391 Ok(())
392 }
393
394 pub fn extend_from_slice(&mut self, data: &[u8]) -> Result<()> {
396 if self.len + data.len() > self.capacity() {
397 return Err(P2PError::Io(std::io::Error::new(
398 std::io::ErrorKind::InvalidInput,
399 "SecureVec capacity exceeded",
400 )));
401 }
402
403 self.memory.as_allocated_mut_slice()[self.len..self.len + data.len()].copy_from_slice(data);
404 self.len += data.len();
405 Ok(())
406 }
407
408 pub fn as_slice(&self) -> &[u8] {
410 &self.memory.as_slice()[..self.len]
411 }
412
413 pub fn clear(&mut self) {
415 self.memory.zeroize();
416 self.len = 0;
417 }
418}
419
420impl Deref for SecureVec {
421 type Target = [u8];
422
423 fn deref(&self) -> &Self::Target {
424 self.as_slice()
425 }
426}
427
428impl SecureString {
429 pub fn with_capacity(capacity: usize) -> Result<Self> {
431 let vec = SecureVec::with_capacity(capacity)?;
432 Ok(Self { vec })
433 }
434
435 pub fn from_plain_str(s: &str) -> Result<Self> {
437 let vec = SecureVec::from_slice(s.as_bytes())?;
438 Ok(Self { vec })
439 }
440
441 pub fn len(&self) -> usize {
443 self.vec.len()
444 }
445
446 pub fn is_empty(&self) -> bool {
448 self.vec.is_empty()
449 }
450
451 pub fn push(&mut self, ch: char) -> Result<()> {
453 let mut buffer = [0u8; 4];
454 let encoded = ch.encode_utf8(&mut buffer);
455 self.vec.extend_from_slice(encoded.as_bytes())
456 }
457
458 pub fn push_str(&mut self, s: &str) -> Result<()> {
460 self.vec.extend_from_slice(s.as_bytes())
461 }
462
463 pub fn as_str(&self) -> Result<&str> {
465 std::str::from_utf8(self.vec.as_slice()).map_err(|e| {
466 P2PError::Io(std::io::Error::new(
467 std::io::ErrorKind::InvalidData,
468 format!("Invalid UTF-8: {e}"),
469 ))
470 })
471 }
472
473 pub fn clear(&mut self) {
475 self.vec.clear();
476 }
477}
478
479impl fmt::Display for SecureString {
480 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
481 match self.as_str() {
482 Ok(s) => write!(f, "{s}"),
483 Err(_) => write!(f, "<invalid UTF-8>"),
484 }
485 }
486}
487
488impl fmt::Debug for SecureString {
489 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
490 write!(f, "SecureString[{}]", self.len())
491 }
492}
493
494impl SecureMemoryPool {
495 pub fn new(total_size: usize, chunk_size: usize) -> Result<Self> {
497 if chunk_size > total_size {
498 return Err(P2PError::Io(std::io::Error::new(
499 std::io::ErrorKind::InvalidInput,
500 "Chunk size cannot exceed total size",
501 )));
502 }
503
504 let pool = Self {
505 available: Mutex::new(VecDeque::new()),
506 total_size,
507 chunk_size,
508 stats: Mutex::new(PoolStats::default()),
509 };
510
511 pool.preallocate_chunks()?;
513
514 Ok(pool)
515 }
516
517 pub fn default_pool() -> Result<Self> {
519 Self::new(DEFAULT_POOL_SIZE, 4096)
520 }
521
522 pub fn allocate(&self, size: usize) -> Result<SecureMemory> {
524 if size > self.chunk_size {
525 if let Ok(mut stats) = self.stats.lock() {
527 stats.pool_misses += 1;
528 stats.total_allocations += 1;
529 stats.active_allocations += 1;
530 stats.total_bytes_allocated += size as u64;
531 stats.current_bytes_in_use += size as u64;
532 }
533 return SecureMemory::new(size);
534 }
535
536 {
538 if let Ok(mut available) = self.available.lock()
539 && let Some(memory) = available.pop_front()
540 {
541 if let Ok(mut stats) = self.stats.lock() {
542 stats.pool_hits += 1;
543 stats.total_allocations += 1;
544 stats.active_allocations += 1;
545 stats.current_bytes_in_use += memory.len() as u64;
546 }
547 return Ok(memory);
548 }
549 }
550
551 if let Ok(mut stats) = self.stats.lock() {
553 stats.pool_misses += 1;
554 stats.total_allocations += 1;
555 stats.active_allocations += 1;
556 stats.total_bytes_allocated += self.chunk_size as u64;
557 stats.current_bytes_in_use += self.chunk_size as u64;
558 }
559
560 SecureMemory::new(self.chunk_size)
561 }
562
563 pub fn deallocate(&self, mut memory: SecureMemory) {
565 memory.zeroize();
567
568 let memory_size = memory.len();
569
570 if memory_size == self.chunk_size {
571 if let Ok(mut available) = self.available.lock() {
573 available.push_back(memory);
574 }
575 }
576 if let Ok(mut stats) = self.stats.lock() {
579 stats.total_deallocations += 1;
580 stats.active_allocations -= 1;
581 stats.current_bytes_in_use -= memory_size as u64;
582 }
583 }
584
585 pub fn stats(&self) -> PoolStats {
587 self.stats.lock().map(|s| s.clone()).unwrap_or_default()
588 }
589
590 fn preallocate_chunks(&self) -> Result<()> {
592 let num_chunks = self.total_size / self.chunk_size;
593 if let Ok(mut available) = self.available.lock() {
594 for _ in 0..num_chunks {
595 let memory = SecureMemory::new(self.chunk_size)?;
596 available.push_back(memory);
597 }
598 }
599
600 Ok(())
601 }
602}
603
604static GLOBAL_POOL: std::sync::OnceLock<Result<SecureMemoryPool>> = std::sync::OnceLock::new();
606
607pub fn global_secure_pool() -> &'static SecureMemoryPool {
609 let result = GLOBAL_POOL.get_or_init(SecureMemoryPool::default_pool);
610 match result {
611 Ok(pool) => pool,
612 Err(_) => match SecureMemoryPool::new(DEFAULT_POOL_SIZE, 4096) {
613 Ok(pool) => {
614 let _ = GLOBAL_POOL.set(Ok(pool));
615 if let Some(Ok(pool)) = GLOBAL_POOL.get() {
616 pool
617 } else {
618 static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
620 once_cell::sync::OnceCell::new();
621 FALLBACK.get_or_init(|| SecureMemoryPool {
622 available: Mutex::new(VecDeque::new()),
623 total_size: DEFAULT_POOL_SIZE,
624 chunk_size: 4096,
625 stats: Mutex::new(PoolStats::default()),
626 })
627 }
628 }
629 Err(_) => {
630 static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
632 once_cell::sync::OnceCell::new();
633 FALLBACK.get_or_init(|| SecureMemoryPool {
634 available: Mutex::new(VecDeque::new()),
635 total_size: DEFAULT_POOL_SIZE,
636 chunk_size: 4096,
637 stats: Mutex::new(PoolStats::default()),
638 })
639 }
640 },
641 }
642}
643
644pub fn allocate_secure(size: usize) -> Result<SecureMemory> {
646 global_secure_pool().allocate(size)
647}
648
649pub fn secure_vec_with_capacity(capacity: usize) -> Result<SecureVec> {
651 let memory = global_secure_pool().allocate(capacity)?;
652 Ok(SecureVec { memory, len: 0 })
653}
654
655pub fn secure_string_with_capacity(capacity: usize) -> Result<SecureString> {
657 let vec = secure_vec_with_capacity(capacity)?;
658 Ok(SecureString { vec })
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_memory_basic() {
        let mut memory = SecureMemory::new(1024).unwrap();

        assert_eq!(memory.len(), 1024);
        assert!(!memory.is_empty());

        memory.as_mut_slice()[0] = 42;
        assert_eq!(memory.as_slice()[0], 42);

        memory.zeroize();
        assert_eq!(memory.as_slice()[0], 0);
    }

    #[test]
    fn test_secure_memory_rejects_invalid_sizes() {
        assert!(SecureMemory::new(0).is_err());
        assert!(SecureMemory::new(MAX_SECURE_ALLOCATION + 1).is_err());
        assert!(SecureMemory::from_slice(&[]).is_err());
    }

    #[test]
    fn test_secure_memory_alignment_rounding() {
        let memory = SecureMemory::new(100).unwrap();

        // The allocation is rounded up to SECURE_ALIGNMENT; the data view is not.
        assert_eq!(memory.len(), 128);
        assert_eq!(memory.as_slice().len(), 100);
        assert_eq!(memory.as_allocated_slice().len(), 128);
        assert!(memory.as_allocated_slice().iter().all(|&b| b == 0));
    }

    #[test]
    fn test_secure_memory_constant_time_comparison() {
        let memory1 = SecureMemory::from_slice(b"hello").unwrap();
        let memory2 = SecureMemory::from_slice(b"hello").unwrap();
        let memory3 = SecureMemory::from_slice(b"world").unwrap();
        let shorter = SecureMemory::from_slice(b"hell").unwrap();

        assert!(memory1.constant_time_eq(&memory2));
        assert!(!memory1.constant_time_eq(&memory3));
        assert!(!memory1.constant_time_eq(&shorter));
    }

    #[test]
    fn test_secure_vec() {
        let mut vec = SecureVec::with_capacity(100).unwrap();

        vec.push(1).unwrap();
        vec.push(2).unwrap();
        vec.extend_from_slice(&[3, 4, 5]).unwrap();

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.as_slice(), &[1, 2, 3, 4, 5]);

        vec.clear();
        assert_eq!(vec.len(), 0);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_secure_string() {
        let mut string = SecureString::with_capacity(100).unwrap();

        string.push('H').unwrap();
        string.push_str("ello").unwrap();

        assert_eq!(string.as_str().unwrap(), "Hello");
        assert_eq!(string.len(), 5);

        string.clear();
        assert_eq!(string.len(), 0);
        assert!(string.is_empty());
    }

    #[test]
    fn test_secure_string_debug_is_redacted() {
        let secret = SecureString::from_plain_str("hunter2").unwrap();
        assert_eq!(format!("{secret:?}"), "SecureString[7]");
    }

    #[test]
    fn test_secure_memory_pool() {
        let pool = SecureMemoryPool::new(8192, 1024).unwrap();

        let memory1 = pool.allocate(512).unwrap();
        let memory2 = pool.allocate(1024).unwrap();

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.active_allocations, 2);

        pool.deallocate(memory1);
        pool.deallocate(memory2);

        let stats = pool.stats();
        assert_eq!(stats.total_deallocations, 2);
        assert_eq!(stats.active_allocations, 0);
    }

    #[test]
    fn test_secure_memory_pool_rejects_oversized_chunks() {
        assert!(SecureMemoryPool::new(1024, 2048).is_err());
    }

    #[test]
    fn test_global_pool() {
        // Small requests are served with whole 4096-byte chunks from the global pool.
        let memory = allocate_secure(256).unwrap();
        assert_eq!(memory.len(), 4096);

        let vec = secure_vec_with_capacity(128).unwrap();
        assert_eq!(vec.capacity(), 4096);

        let string = secure_string_with_capacity(64).unwrap();
        assert_eq!(string.vec.capacity(), 4096);
    }
}