1#![allow(unsafe_code)] use crate::{P2PError, Result};
35use std::alloc::{Layout, alloc_zeroed, dealloc};
36use std::collections::VecDeque;
37use std::fmt;
38use std::ops::Deref;
39use std::ptr::{self, NonNull};
40use std::sync::Mutex;
41
42#[cfg(unix)]
43use libc::{mlock, munlock};
44
45#[cfg(windows)]
46use winapi::um::memoryapi::{VirtualLock, VirtualUnlock};
47
/// Largest single allocation `SecureMemory::new` accepts, in bytes (64 KiB).
const MAX_SECURE_ALLOCATION: usize = 64 * 1024;

/// Total byte budget used by `SecureMemoryPool::default_pool` (1 MiB).
const DEFAULT_POOL_SIZE: usize = 1 << 20;

/// Alignment, and size granularity, of every secure allocation in bytes.
const SECURE_ALIGNMENT: usize = 64;
56
/// Heap buffer for secret bytes.
///
/// The buffer is allocated zeroed with `alloc_zeroed`, and `lock_memory`
/// attempts to pin it with `mlock`/`VirtualLock`. It is wiped with volatile
/// writes and unlocked before `dealloc` when dropped.
pub struct SecureMemory {
    /// Start of the allocation returned by `alloc_zeroed`; never null.
    ptr: NonNull<u8>,
    /// Number of bytes exposed through `as_slice` / `as_mut_slice`.
    size: usize,
    /// Set by `lock_memory`; checked by `unlock_memory` before unlocking.
    locked: bool,
    /// Layout used for both `alloc_zeroed` and `dealloc`.
    layout: Layout,
}
68
// SAFETY: `SecureMemory` exclusively owns its allocation. `ptr` is private,
// and the only views of it are slices borrowed from `self`. Moving the owner
// to another thread therefore cannot create aliasing.
unsafe impl Send for SecureMemory {}
// SAFETY: through `&SecureMemory` the bytes can only be read (`as_slice`,
// `constant_time_eq`). Every mutation (`as_mut_slice`, `zeroize`, locking)
// requires `&mut self`.
unsafe impl Sync for SecureMemory {}
73
/// Fixed-capacity byte vector stored in [`SecureMemory`].
///
/// The capacity never grows. `push` and `extend_from_slice` return an error
/// instead of reallocating.
pub struct SecureVec {
    /// Backing storage; its `len()` is the vector's capacity.
    memory: SecureMemory,
    /// Number of bytes written so far; always `<= memory.len()`.
    len: usize,
}
81
/// UTF-8 text whose bytes live in a [`SecureVec`].
///
/// The public API only appends `&str`/`char` data, so the contents stay valid
/// UTF-8.
pub struct SecureString {
    /// Backing byte storage.
    vec: SecureVec,
}
87
/// A pool of pre-allocated, equally sized [`SecureMemory`] chunks.
pub struct SecureMemoryPool {
    /// Free list of chunks ready to be handed out. Entries are zeroed: they
    /// are either fresh from `alloc_zeroed` or wiped by `deallocate` before
    /// being pushed back.
    available: Mutex<VecDeque<SecureMemory>>,
    /// Byte budget used to size the pre-allocated free list
    /// (`total_size / chunk_size` chunks).
    total_size: usize,
    /// Size of each pooled chunk. Requests up to this size receive a whole chunk.
    chunk_size: usize,
    /// Usage counters, snapshotted by `stats()`.
    stats: Mutex<PoolStats>,
}
99
/// Usage counters for a [`SecureMemoryPool`], returned as a snapshot by
/// `SecureMemoryPool::stats`.
///
/// Only `SecureMemoryPool::allocate` and `SecureMemoryPool::deallocate` update
/// these counters. A `SecureMemory` taken from the pool and simply dropped is
/// freed by its `Drop` impl without being recorded here.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of allocations recorded by `allocate`.
    pub total_allocations: u64,
    /// Number of calls to `deallocate`.
    pub total_deallocations: u64,
    /// Allocations recorded minus deallocations recorded.
    pub active_allocations: u64,
    /// Allocations served from the pool's free list.
    pub pool_hits: u64,
    /// Allocations not served from the free list (oversized request or empty list).
    pub pool_misses: u64,
    /// Bytes recorded for allocations that were not served from the free list.
    pub total_bytes_allocated: u64,
    /// Bytes recorded as handed out and not yet passed back to `deallocate`.
    pub current_bytes_in_use: u64,
}
118
/// Error kinds for secure-memory operations.
///
/// NOTE(review): the functions in this module return `P2PError::Io` rather
/// than constructing this type. Confirm whether it is still needed.
#[derive(Debug, Clone)]
pub enum SecureMemoryError {
    /// Displayed as "Allocation failed: <detail>".
    AllocationFailed(String),
    /// Displayed as "Memory locking failed: <detail>".
    LockingFailed(String),
    /// Displayed as "Invalid parameters: <detail>".
    InvalidParameters(String),
    /// Displayed as "Secure memory pool exhausted".
    PoolExhausted,
    /// Displayed as "Operation not supported: <detail>".
    NotSupported(String),
}
133
134impl std::fmt::Display for SecureMemoryError {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 SecureMemoryError::AllocationFailed(msg) => write!(f, "Allocation failed: {msg}"),
138 SecureMemoryError::LockingFailed(msg) => write!(f, "Memory locking failed: {msg}"),
139 SecureMemoryError::InvalidParameters(msg) => write!(f, "Invalid parameters: {msg}"),
140 SecureMemoryError::PoolExhausted => write!(f, "Secure memory pool exhausted"),
141 SecureMemoryError::NotSupported(msg) => write!(f, "Operation not supported: {msg}"),
142 }
143 }
144}
145
// Marker impl: uses the default `source()` (no underlying cause) and the
// `Display` impl for the message.
impl std::error::Error for SecureMemoryError {}
147
148impl SecureMemory {
149 pub fn new(size: usize) -> Result<Self> {
151 if size == 0 {
152 return Err(P2PError::Io(std::io::Error::new(
153 std::io::ErrorKind::InvalidInput,
154 "Cannot allocate zero-sized memory",
155 )));
156 }
157
158 if size > MAX_SECURE_ALLOCATION {
159 return Err(P2PError::Io(std::io::Error::new(
160 std::io::ErrorKind::InvalidInput,
161 format!("Allocation size {size} exceeds maximum {MAX_SECURE_ALLOCATION}"),
162 )));
163 }
164
165 let aligned_size = (size + SECURE_ALIGNMENT - 1) & !(SECURE_ALIGNMENT - 1);
167
168 let layout = Layout::from_size_align(aligned_size, SECURE_ALIGNMENT).map_err(|e| {
170 P2PError::Io(std::io::Error::new(
171 std::io::ErrorKind::InvalidInput,
172 format!("Invalid layout: {e}"),
173 ))
174 })?;
175
176 let ptr = unsafe { alloc_zeroed(layout) };
178 if ptr.is_null() {
179 return Err(P2PError::Io(std::io::Error::new(
180 std::io::ErrorKind::OutOfMemory,
181 "Memory allocation failed",
182 )));
183 }
184
185 let ptr = NonNull::new(ptr).ok_or_else(|| {
186 P2PError::Io(std::io::Error::new(
187 std::io::ErrorKind::OutOfMemory,
188 "Null pointer returned from allocator",
189 ))
190 })?;
191
192 let mut memory = Self {
193 ptr,
194 size: aligned_size,
195 locked: false,
196 layout,
197 };
198
199 if let Err(e) = memory.lock_memory() {
201 tracing::warn!("Failed to lock secure memory: {}", e);
202 }
203
204 Ok(memory)
205 }
206
207 pub fn from_slice(data: &[u8]) -> Result<Self> {
209 let mut memory = Self::new(data.len())?;
210 memory.as_mut_slice().copy_from_slice(data);
211 Ok(memory)
212 }
213
214 pub fn len(&self) -> usize {
216 self.size
217 }
218
219 pub fn is_empty(&self) -> bool {
221 self.size == 0
222 }
223
224 pub fn as_slice(&self) -> &[u8] {
226 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
227 }
228
229 pub fn as_mut_slice(&mut self) -> &mut [u8] {
231 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
232 }
233
234 pub fn constant_time_eq(&self, other: &SecureMemory) -> bool {
236 if self.size != other.size {
237 return false;
238 }
239
240 let a = self.as_slice();
241 let b = other.as_slice();
242
243 let mut result = 0u8;
245 for i in 0..self.size {
246 result |= a[i] ^ b[i];
247 }
248
249 result == 0
250 }
251
252 fn lock_memory(&mut self) -> Result<()> {
254 if self.locked {
255 return Ok(());
256 }
257
258 #[cfg(unix)]
259 {
260 let result = unsafe { mlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
261 if result != 0 {
262 return Err(P2PError::Io(std::io::Error::new(
263 std::io::ErrorKind::PermissionDenied,
264 "Failed to lock memory pages",
265 )));
266 }
267 }
268
269 #[cfg(windows)]
270 {
271 let result =
272 unsafe { VirtualLock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
273 if result == 0 {
274 return Err(P2PError::Io(std::io::Error::new(
275 std::io::ErrorKind::PermissionDenied,
276 "VirtualLock failed",
277 )));
278 }
279 }
280
281 #[cfg(not(any(unix, windows)))]
282 {
283 tracing::warn!("Memory locking not supported on this platform");
284 }
285
286 self.locked = true;
287 Ok(())
288 }
289
290 fn unlock_memory(&mut self) {
292 if !self.locked {
293 return;
294 }
295
296 #[cfg(unix)]
297 {
298 unsafe { munlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
299 }
300
301 #[cfg(windows)]
302 {
303 unsafe { VirtualUnlock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
304 }
305
306 self.locked = false;
307 }
308
309 pub fn zeroize(&mut self) {
311 unsafe {
312 ptr::write_volatile(self.ptr.as_ptr(), 0u8);
314
315 for i in 0..self.size {
317 ptr::write_volatile(self.ptr.as_ptr().add(i), 0u8);
318 }
319 }
320 }
321}
322
impl Drop for SecureMemory {
    /// Wipes, unlocks, and frees the allocation, in that order.
    fn drop(&mut self) {
        // Wipe first. Presumably this is so the bytes are cleared while the
        // pages are still pinned, before they can be unlocked and reused.
        self.zeroize();

        // No-op unless `lock_memory` previously recorded a lock.
        self.unlock_memory();

        // SAFETY: `ptr` was returned by `alloc_zeroed(self.layout)` in
        // `SecureMemory::new` and is freed exactly once, here.
        unsafe {
            dealloc(self.ptr.as_ptr(), self.layout);
        }
    }
}
337
338impl SecureVec {
339 pub fn with_capacity(capacity: usize) -> Result<Self> {
341 let memory = SecureMemory::new(capacity)?;
342 Ok(Self { memory, len: 0 })
343 }
344
345 pub fn from_slice(data: &[u8]) -> Result<Self> {
347 let memory = SecureMemory::from_slice(data)?;
348 let len = data.len();
349 Ok(Self { memory, len })
350 }
351
352 pub fn len(&self) -> usize {
354 self.len
355 }
356
357 pub fn is_empty(&self) -> bool {
359 self.len == 0
360 }
361
362 pub fn capacity(&self) -> usize {
364 self.memory.len()
365 }
366
367 pub fn push(&mut self, value: u8) -> Result<()> {
369 if self.len >= self.capacity() {
370 return Err(P2PError::Io(std::io::Error::new(
371 std::io::ErrorKind::InvalidInput,
372 "SecureVec capacity exceeded",
373 )));
374 }
375
376 self.memory.as_mut_slice()[self.len] = value;
377 self.len += 1;
378 Ok(())
379 }
380
381 pub fn extend_from_slice(&mut self, data: &[u8]) -> Result<()> {
383 if self.len + data.len() > self.capacity() {
384 return Err(P2PError::Io(std::io::Error::new(
385 std::io::ErrorKind::InvalidInput,
386 "SecureVec capacity exceeded",
387 )));
388 }
389
390 self.memory.as_mut_slice()[self.len..self.len + data.len()].copy_from_slice(data);
391 self.len += data.len();
392 Ok(())
393 }
394
395 pub fn as_slice(&self) -> &[u8] {
397 &self.memory.as_slice()[..self.len]
398 }
399
400 pub fn clear(&mut self) {
402 self.memory.zeroize();
403 self.len = 0;
404 }
405}
406
407impl Deref for SecureVec {
408 type Target = [u8];
409
410 fn deref(&self) -> &Self::Target {
411 self.as_slice()
412 }
413}
414
415impl SecureString {
416 pub fn with_capacity(capacity: usize) -> Result<Self> {
418 let vec = SecureVec::with_capacity(capacity)?;
419 Ok(Self { vec })
420 }
421
422 pub fn from_plain_str(s: &str) -> Result<Self> {
424 let vec = SecureVec::from_slice(s.as_bytes())?;
425 Ok(Self { vec })
426 }
427
428 pub fn len(&self) -> usize {
430 self.vec.len()
431 }
432
433 pub fn is_empty(&self) -> bool {
435 self.vec.is_empty()
436 }
437
438 pub fn push(&mut self, ch: char) -> Result<()> {
440 let mut buffer = [0u8; 4];
441 let encoded = ch.encode_utf8(&mut buffer);
442 self.vec.extend_from_slice(encoded.as_bytes())
443 }
444
445 pub fn push_str(&mut self, s: &str) -> Result<()> {
447 self.vec.extend_from_slice(s.as_bytes())
448 }
449
450 pub fn as_str(&self) -> Result<&str> {
452 std::str::from_utf8(self.vec.as_slice()).map_err(|e| {
453 P2PError::Io(std::io::Error::new(
454 std::io::ErrorKind::InvalidData,
455 format!("Invalid UTF-8: {e}"),
456 ))
457 })
458 }
459
460 pub fn clear(&mut self) {
462 self.vec.clear();
463 }
464}
465
466impl fmt::Display for SecureString {
467 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468 match self.as_str() {
469 Ok(s) => write!(f, "{s}"),
470 Err(_) => write!(f, "<invalid UTF-8>"),
471 }
472 }
473}
474
475impl fmt::Debug for SecureString {
476 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477 write!(f, "SecureString[{}]", self.len())
478 }
479}
480
481impl SecureMemoryPool {
482 pub fn new(total_size: usize, chunk_size: usize) -> Result<Self> {
484 if chunk_size > total_size {
485 return Err(P2PError::Io(std::io::Error::new(
486 std::io::ErrorKind::InvalidInput,
487 "Chunk size cannot exceed total size",
488 )));
489 }
490
491 let pool = Self {
492 available: Mutex::new(VecDeque::new()),
493 total_size,
494 chunk_size,
495 stats: Mutex::new(PoolStats::default()),
496 };
497
498 pool.preallocate_chunks()?;
500
501 Ok(pool)
502 }
503
504 pub fn default_pool() -> Result<Self> {
506 Self::new(DEFAULT_POOL_SIZE, 4096)
507 }
508
509 pub fn allocate(&self, size: usize) -> Result<SecureMemory> {
511 if size > self.chunk_size {
512 if let Ok(mut stats) = self.stats.lock() {
514 stats.pool_misses += 1;
515 stats.total_allocations += 1;
516 stats.active_allocations += 1;
517 stats.total_bytes_allocated += size as u64;
518 stats.current_bytes_in_use += size as u64;
519 }
520 return SecureMemory::new(size);
521 }
522
523 {
525 if let Ok(mut available) = self.available.lock()
526 && let Some(memory) = available.pop_front()
527 {
528 if let Ok(mut stats) = self.stats.lock() {
529 stats.pool_hits += 1;
530 stats.total_allocations += 1;
531 stats.active_allocations += 1;
532 stats.current_bytes_in_use += memory.len() as u64;
533 }
534 return Ok(memory);
535 }
536 }
537
538 if let Ok(mut stats) = self.stats.lock() {
540 stats.pool_misses += 1;
541 stats.total_allocations += 1;
542 stats.active_allocations += 1;
543 stats.total_bytes_allocated += self.chunk_size as u64;
544 stats.current_bytes_in_use += self.chunk_size as u64;
545 }
546
547 SecureMemory::new(self.chunk_size)
548 }
549
550 pub fn deallocate(&self, mut memory: SecureMemory) {
552 memory.zeroize();
554
555 let memory_size = memory.len();
556
557 if memory_size == self.chunk_size {
558 if let Ok(mut available) = self.available.lock() {
560 available.push_back(memory);
561 }
562 }
563 if let Ok(mut stats) = self.stats.lock() {
566 stats.total_deallocations += 1;
567 stats.active_allocations -= 1;
568 stats.current_bytes_in_use -= memory_size as u64;
569 }
570 }
571
572 pub fn stats(&self) -> PoolStats {
574 self.stats.lock().map(|s| s.clone()).unwrap_or_default()
575 }
576
577 fn preallocate_chunks(&self) -> Result<()> {
579 let num_chunks = self.total_size / self.chunk_size;
580 if let Ok(mut available) = self.available.lock() {
581 for _ in 0..num_chunks {
582 let memory = SecureMemory::new(self.chunk_size)?;
583 available.push_back(memory);
584 }
585 }
586
587 Ok(())
588 }
589}
590
/// Process-wide pool slot, initialised on first use by `global_secure_pool`.
///
/// It stores the outcome of the first `SecureMemoryPool::default_pool` call,
/// including a failure. A stored `Err` is never replaced.
static GLOBAL_POOL: std::sync::OnceLock<Result<SecureMemoryPool>> = std::sync::OnceLock::new();
593
594pub fn global_secure_pool() -> &'static SecureMemoryPool {
596 let result = GLOBAL_POOL.get_or_init(SecureMemoryPool::default_pool);
597 match result {
598 Ok(pool) => pool,
599 Err(_) => match SecureMemoryPool::new(DEFAULT_POOL_SIZE, 4096) {
600 Ok(pool) => {
601 let _ = GLOBAL_POOL.set(Ok(pool));
602 if let Some(Ok(pool)) = GLOBAL_POOL.get() {
603 pool
604 } else {
605 static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
607 once_cell::sync::OnceCell::new();
608 FALLBACK.get_or_init(|| SecureMemoryPool {
609 available: Mutex::new(VecDeque::new()),
610 total_size: DEFAULT_POOL_SIZE,
611 chunk_size: 4096,
612 stats: Mutex::new(PoolStats::default()),
613 })
614 }
615 }
616 Err(_) => {
617 static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
619 once_cell::sync::OnceCell::new();
620 FALLBACK.get_or_init(|| SecureMemoryPool {
621 available: Mutex::new(VecDeque::new()),
622 total_size: DEFAULT_POOL_SIZE,
623 chunk_size: 4096,
624 stats: Mutex::new(PoolStats::default()),
625 })
626 }
627 },
628 }
629}
630
631pub fn allocate_secure(size: usize) -> Result<SecureMemory> {
633 global_secure_pool().allocate(size)
634}
635
636pub fn secure_vec_with_capacity(capacity: usize) -> Result<SecureVec> {
638 let memory = global_secure_pool().allocate(capacity)?;
639 Ok(SecureVec { memory, len: 0 })
640}
641
642pub fn secure_string_with_capacity(capacity: usize) -> Result<SecureString> {
644 let vec = secure_vec_with_capacity(capacity)?;
645 Ok(SecureString { vec })
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_memory_basic() {
        let mut memory = SecureMemory::new(1024).unwrap();

        assert_eq!(memory.len(), 1024);
        assert!(!memory.is_empty());

        memory.as_mut_slice()[0] = 42;
        assert_eq!(memory.as_slice()[0], 42);

        memory.zeroize();
        assert_eq!(memory.as_slice()[0], 0);
    }

    #[test]
    fn test_secure_memory_rejects_invalid_sizes() {
        assert!(SecureMemory::new(0).is_err());
        assert!(SecureMemory::new(MAX_SECURE_ALLOCATION + 1).is_err());
    }

    #[test]
    fn test_secure_memory_constant_time_comparison() {
        let memory1 = SecureMemory::from_slice(b"hello").unwrap();
        let memory2 = SecureMemory::from_slice(b"hello").unwrap();
        let memory3 = SecureMemory::from_slice(b"world").unwrap();

        assert!(memory1.constant_time_eq(&memory2));
        assert!(!memory1.constant_time_eq(&memory3));
    }

    #[test]
    fn test_secure_vec() {
        let mut vec = SecureVec::with_capacity(100).unwrap();

        vec.push(1).unwrap();
        vec.push(2).unwrap();
        vec.extend_from_slice(&[3, 4, 5]).unwrap();

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.as_slice(), &[1, 2, 3, 4, 5]);

        vec.clear();
        assert_eq!(vec.len(), 0);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_secure_vec_capacity_exceeded() {
        let mut vec = SecureVec::with_capacity(64).unwrap();
        vec.extend_from_slice(&[7u8; 64]).unwrap();

        // A full vector rejects further writes and leaves its length unchanged.
        assert!(vec.push(1).is_err());
        assert!(vec.extend_from_slice(&[1]).is_err());
        assert_eq!(vec.len(), 64);
    }

    #[test]
    fn test_secure_string() {
        let mut string = SecureString::with_capacity(100).unwrap();

        string.push('H').unwrap();
        string.push_str("ello").unwrap();

        assert_eq!(string.as_str().unwrap(), "Hello");
        assert_eq!(string.len(), 5);

        string.clear();
        assert_eq!(string.len(), 0);
        assert!(string.is_empty());
    }

    #[test]
    fn test_secure_string_debug_is_redacted() {
        let mut string = SecureString::with_capacity(64).unwrap();
        string.push_str("hunter2").unwrap();

        let rendered = format!("{string:?}");
        assert_eq!(rendered, "SecureString[7]");
        assert!(!rendered.contains("hunter2"));
    }

    #[test]
    fn test_secure_memory_pool() {
        let pool = SecureMemoryPool::new(8192, 1024).unwrap();

        let memory1 = pool.allocate(512).unwrap();
        let memory2 = pool.allocate(1024).unwrap();

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.active_allocations, 2);

        pool.deallocate(memory1);
        pool.deallocate(memory2);

        let stats = pool.stats();
        assert_eq!(stats.total_deallocations, 2);
        assert_eq!(stats.active_allocations, 0);
    }

    #[test]
    fn test_global_pool() {
        // Requests up to the pool's chunk size receive a whole chunk, so the
        // usable size is at least (not exactly) what was requested.
        let memory = allocate_secure(256).unwrap();
        assert!(memory.len() >= 256);

        let vec = secure_vec_with_capacity(128).unwrap();
        assert!(vec.capacity() >= 128);
        assert!(vec.is_empty());

        let string = secure_string_with_capacity(64).unwrap();
        assert!(string.vec.capacity() >= 64);
        assert!(string.is_empty());
    }
}