1#![allow(unsafe_code)] use crate::{P2PError, Result};
35use std::alloc::{Layout, alloc_zeroed, dealloc};
36use std::collections::VecDeque;
37use std::fmt;
38use std::ops::Deref;
39use std::ptr::{self, NonNull};
40use std::sync::Mutex;
41use subtle::ConstantTimeEq;
42
43#[cfg(unix)]
44use libc::{mlock, munlock};
45
46#[cfg(windows)]
47use winapi::um::memoryapi::{VirtualLock, VirtualUnlock};
48
/// Largest size in bytes accepted by `SecureMemory::new` (64 KiB).
const MAX_SECURE_ALLOCATION: usize = 64 * 1024;

/// Total byte budget of the default pool (1 MiB).
const DEFAULT_POOL_SIZE: usize = 1 << 20;

/// Alignment and rounding granularity of every secure allocation; must be a
/// power of two because `SecureMemory::new` rounds sizes with a bit mask.
const SECURE_ALIGNMENT: usize = 64;
57
/// A heap allocation for secret data.
///
/// The memory is zero-initialised on creation and wiped when dropped. Creation
/// also makes a best-effort attempt to lock the pages into RAM.
pub struct SecureMemory {
    /// Start of the allocation obtained from `alloc_zeroed`.
    ptr: NonNull<u8>,
    /// Allocated size in bytes: the requested size rounded up to `SECURE_ALIGNMENT`.
    size: usize,
    /// Requested size in bytes; the length exposed by `as_slice`/`as_mut_slice`.
    data_len: usize,
    /// Set once `lock_memory` has completed without error; checked by `unlock_memory`.
    locked: bool,
    /// Layout used for the allocation; required again to deallocate it.
    layout: Layout,
}

// SAFETY: `SecureMemory` exclusively owns its allocation (created in
// `SecureMemory::new`, freed only in `Drop`), so it may be moved across threads.
unsafe impl Send for SecureMemory {}
// SAFETY: methods taking `&self` only read through `ptr`; every write requires
// `&mut self`, so shared references cannot race.
unsafe impl Sync for SecureMemory {}
76
/// A fixed-capacity byte buffer stored in [`SecureMemory`].
///
/// It never reallocates: appending beyond the capacity returns an error.
pub struct SecureVec {
    /// Backing allocation; its allocated size is the vector's capacity.
    memory: SecureMemory,
    /// Number of bytes written so far.
    len: usize,
}

/// A string whose bytes live in a [`SecureVec`].
pub struct SecureString {
    /// UTF-8 bytes of the string; `as_str` validates them on each access.
    vec: SecureVec,
}
90
/// A pool of equally sized [`SecureMemory`] chunks that are preallocated and
/// recycled through `allocate` / `deallocate`.
pub struct SecureMemoryPool {
    /// Zeroed chunks ready to be handed out by `allocate`.
    available: Mutex<VecDeque<SecureMemory>>,
    /// Byte budget that determines how many chunks are preallocated.
    total_size: usize,
    /// Size in bytes of each pooled chunk; larger requests bypass the pool.
    chunk_size: usize,
    /// Allocation counters reported by `stats`.
    stats: Mutex<PoolStats>,
}
102
/// Counters describing a [`SecureMemoryPool`]'s activity, as returned by
/// `SecureMemoryPool::stats`.
///
/// Only `deallocate` decrements the "active"/"in use" counters. Memory that is
/// simply dropped instead of being passed back remains counted.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of requests counted by `allocate`.
    pub total_allocations: u64,
    /// Number of calls to `deallocate`.
    pub total_deallocations: u64,
    /// Allocations counted but not yet passed to `deallocate`.
    pub active_allocations: u64,
    /// Requests served from the preallocated free list.
    pub pool_hits: u64,
    /// Requests that needed a new allocation (oversized, or free list empty).
    pub pool_misses: u64,
    /// Bytes counted for new (non-pooled) allocations.
    pub total_bytes_allocated: u64,
    /// Bytes counted as handed out and not yet returned via `deallocate`.
    pub current_bytes_in_use: u64,
}
121
/// Error conditions specific to secure memory handling.
#[derive(Debug, Clone)]
pub enum SecureMemoryError {
    /// Allocation failed; carries a description.
    AllocationFailed(String),
    /// Memory locking failed; carries a description.
    LockingFailed(String),
    /// Invalid parameters were supplied; carries a description.
    InvalidParameters(String),
    /// The secure memory pool has no memory left.
    PoolExhausted,
    /// The operation is not supported; carries a description.
    NotSupported(String),
}

impl std::fmt::Display for SecureMemoryError {
    /// Renders a human-readable message, appending the carried detail when present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PoolExhausted => f.write_str("Secure memory pool exhausted"),
            Self::AllocationFailed(msg) => write!(f, "Allocation failed: {msg}"),
            Self::LockingFailed(msg) => write!(f, "Memory locking failed: {msg}"),
            Self::InvalidParameters(msg) => write!(f, "Invalid parameters: {msg}"),
            Self::NotSupported(msg) => write!(f, "Operation not supported: {msg}"),
        }
    }
}

impl std::error::Error for SecureMemoryError {}
150
151impl SecureMemory {
152 pub fn new(size: usize) -> Result<Self> {
154 if size == 0 {
155 return Err(P2PError::Io(std::io::Error::new(
156 std::io::ErrorKind::InvalidInput,
157 "Cannot allocate zero-sized memory",
158 )));
159 }
160
161 if size > MAX_SECURE_ALLOCATION {
162 return Err(P2PError::Io(std::io::Error::new(
163 std::io::ErrorKind::InvalidInput,
164 format!("Allocation size {size} exceeds maximum {MAX_SECURE_ALLOCATION}"),
165 )));
166 }
167
168 let aligned_size = (size + SECURE_ALIGNMENT - 1) & !(SECURE_ALIGNMENT - 1);
170
171 let layout = Layout::from_size_align(aligned_size, SECURE_ALIGNMENT).map_err(|e| {
173 P2PError::Io(std::io::Error::new(
174 std::io::ErrorKind::InvalidInput,
175 format!("Invalid layout: {e}"),
176 ))
177 })?;
178
179 let ptr = unsafe { alloc_zeroed(layout) };
181 if ptr.is_null() {
182 return Err(P2PError::Io(std::io::Error::new(
183 std::io::ErrorKind::OutOfMemory,
184 "Memory allocation failed",
185 )));
186 }
187
188 let ptr = NonNull::new(ptr).ok_or_else(|| {
189 P2PError::Io(std::io::Error::new(
190 std::io::ErrorKind::OutOfMemory,
191 "Null pointer returned from allocator",
192 ))
193 })?;
194
195 let mut memory = Self {
196 ptr,
197 size: aligned_size,
198 data_len: size,
199 locked: false,
200 layout,
201 };
202
203 if let Err(e) = memory.lock_memory() {
205 tracing::warn!("Failed to lock secure memory: {}", e);
206 }
207
208 Ok(memory)
209 }
210
211 pub fn from_slice(data: &[u8]) -> Result<Self> {
213 let mut memory = Self::new(data.len())?;
214 memory.as_mut_slice()[..data.len()].copy_from_slice(data);
215 Ok(memory)
216 }
217
218 pub fn len(&self) -> usize {
220 self.size
221 }
222
223 pub fn is_empty(&self) -> bool {
225 self.size == 0
226 }
227
228 pub fn as_slice(&self) -> &[u8] {
230 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.data_len) }
231 }
232
233 pub fn as_mut_slice(&mut self) -> &mut [u8] {
235 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.data_len) }
236 }
237
238 pub fn as_allocated_slice(&self) -> &[u8] {
240 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
241 }
242
243 pub fn as_allocated_mut_slice(&mut self) -> &mut [u8] {
245 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
246 }
247
248 pub fn constant_time_eq(&self, other: &SecureMemory) -> bool {
250 if self.data_len != other.data_len {
251 return false;
252 }
253
254 self.as_slice().ct_eq(other.as_slice()).into()
256 }
257
258 fn lock_memory(&mut self) -> Result<()> {
260 if self.locked {
261 return Ok(());
262 }
263
264 #[cfg(unix)]
265 {
266 let result = unsafe { mlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
267 if result != 0 {
268 return Err(P2PError::Io(std::io::Error::new(
269 std::io::ErrorKind::PermissionDenied,
270 "Failed to lock memory pages",
271 )));
272 }
273 }
274
275 #[cfg(windows)]
276 {
277 let result =
278 unsafe { VirtualLock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
279 if result == 0 {
280 return Err(P2PError::Io(std::io::Error::new(
281 std::io::ErrorKind::PermissionDenied,
282 "VirtualLock failed",
283 )));
284 }
285 }
286
287 #[cfg(not(any(unix, windows)))]
288 {
289 tracing::warn!("Memory locking not supported on this platform");
290 }
291
292 self.locked = true;
293 Ok(())
294 }
295
296 fn unlock_memory(&mut self) {
298 if !self.locked {
299 return;
300 }
301
302 #[cfg(unix)]
303 {
304 unsafe { munlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
305 }
306
307 #[cfg(windows)]
308 {
309 unsafe { VirtualUnlock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
310 }
311
312 self.locked = false;
313 }
314
315 pub fn zeroize(&mut self) {
317 unsafe {
318 ptr::write_volatile(self.ptr.as_ptr(), 0u8);
320
321 for i in 0..self.size {
323 ptr::write_volatile(self.ptr.as_ptr().add(i), 0u8);
324 }
325 }
326 }
327}
328
impl Drop for SecureMemory {
    /// Wipes, unlocks and frees the allocation — in that order.
    fn drop(&mut self) {
        // Wipe first so the secret never reaches the allocator's free list.
        self.zeroize();

        // Unlock only after wiping, so pages that were locked are not
        // swappable while they still hold secret data.
        self.unlock_memory();

        // SAFETY: `ptr` was returned by `alloc_zeroed(self.layout)` in
        // `SecureMemory::new` and is freed only here, once.
        unsafe {
            dealloc(self.ptr.as_ptr(), self.layout);
        }
    }
}
343
344impl SecureVec {
345 pub fn with_capacity(capacity: usize) -> Result<Self> {
347 let memory = SecureMemory::new(capacity)?;
348 Ok(Self { memory, len: 0 })
349 }
350
351 pub fn from_slice(data: &[u8]) -> Result<Self> {
353 let memory = SecureMemory::from_slice(data)?;
354 let len = data.len();
355 Ok(Self { memory, len })
356 }
357
358 pub fn len(&self) -> usize {
360 self.len
361 }
362
363 pub fn is_empty(&self) -> bool {
365 self.len == 0
366 }
367
368 pub fn capacity(&self) -> usize {
370 self.memory.len()
371 }
372
373 pub fn push(&mut self, value: u8) -> Result<()> {
375 if self.len >= self.capacity() {
376 return Err(P2PError::Io(std::io::Error::new(
377 std::io::ErrorKind::InvalidInput,
378 "SecureVec capacity exceeded",
379 )));
380 }
381
382 self.memory.as_allocated_mut_slice()[self.len] = value;
383 self.len += 1;
384 Ok(())
385 }
386
387 pub fn extend_from_slice(&mut self, data: &[u8]) -> Result<()> {
389 if self.len + data.len() > self.capacity() {
390 return Err(P2PError::Io(std::io::Error::new(
391 std::io::ErrorKind::InvalidInput,
392 "SecureVec capacity exceeded",
393 )));
394 }
395
396 self.memory.as_allocated_mut_slice()[self.len..self.len + data.len()].copy_from_slice(data);
397 self.len += data.len();
398 Ok(())
399 }
400
401 pub fn as_slice(&self) -> &[u8] {
403 &self.memory.as_slice()[..self.len]
404 }
405
406 pub fn clear(&mut self) {
408 self.memory.zeroize();
409 self.len = 0;
410 }
411}
412
413impl Deref for SecureVec {
414 type Target = [u8];
415
416 fn deref(&self) -> &Self::Target {
417 self.as_slice()
418 }
419}
420
421impl SecureString {
422 pub fn with_capacity(capacity: usize) -> Result<Self> {
424 let vec = SecureVec::with_capacity(capacity)?;
425 Ok(Self { vec })
426 }
427
428 pub fn from_plain_str(s: &str) -> Result<Self> {
430 let vec = SecureVec::from_slice(s.as_bytes())?;
431 Ok(Self { vec })
432 }
433
434 pub fn len(&self) -> usize {
436 self.vec.len()
437 }
438
439 pub fn is_empty(&self) -> bool {
441 self.vec.is_empty()
442 }
443
444 pub fn push(&mut self, ch: char) -> Result<()> {
446 let mut buffer = [0u8; 4];
447 let encoded = ch.encode_utf8(&mut buffer);
448 self.vec.extend_from_slice(encoded.as_bytes())
449 }
450
451 pub fn push_str(&mut self, s: &str) -> Result<()> {
453 self.vec.extend_from_slice(s.as_bytes())
454 }
455
456 pub fn as_str(&self) -> Result<&str> {
458 std::str::from_utf8(self.vec.as_slice()).map_err(|e| {
459 P2PError::Io(std::io::Error::new(
460 std::io::ErrorKind::InvalidData,
461 format!("Invalid UTF-8: {e}"),
462 ))
463 })
464 }
465
466 pub fn clear(&mut self) {
468 self.vec.clear();
469 }
470}
471
472impl fmt::Display for SecureString {
473 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474 match self.as_str() {
475 Ok(s) => write!(f, "{s}"),
476 Err(_) => write!(f, "<invalid UTF-8>"),
477 }
478 }
479}
480
481impl fmt::Debug for SecureString {
482 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
483 write!(f, "SecureString[{}]", self.len())
484 }
485}
486
487impl SecureMemoryPool {
488 pub fn new(total_size: usize, chunk_size: usize) -> Result<Self> {
490 if chunk_size > total_size {
491 return Err(P2PError::Io(std::io::Error::new(
492 std::io::ErrorKind::InvalidInput,
493 "Chunk size cannot exceed total size",
494 )));
495 }
496
497 let pool = Self {
498 available: Mutex::new(VecDeque::new()),
499 total_size,
500 chunk_size,
501 stats: Mutex::new(PoolStats::default()),
502 };
503
504 pool.preallocate_chunks()?;
506
507 Ok(pool)
508 }
509
510 pub fn default_pool() -> Result<Self> {
512 Self::new(DEFAULT_POOL_SIZE, 4096)
513 }
514
515 pub fn allocate(&self, size: usize) -> Result<SecureMemory> {
517 if size > self.chunk_size {
518 match self.stats.lock() {
520 Ok(mut stats) => {
521 stats.pool_misses += 1;
522 stats.total_allocations += 1;
523 stats.active_allocations += 1;
524 stats.total_bytes_allocated += size as u64;
525 stats.current_bytes_in_use += size as u64;
526 }
527 Err(e) => {
528 tracing::warn!("Memory pool stats mutex poisoned in allocate: {}", e);
529 }
530 }
531 return SecureMemory::new(size);
532 }
533
534 {
536 if let Ok(mut available) = self.available.lock()
537 && let Some(memory) = available.pop_front()
538 {
539 match self.stats.lock() {
540 Ok(mut stats) => {
541 stats.pool_hits += 1;
542 stats.total_allocations += 1;
543 stats.active_allocations += 1;
544 stats.current_bytes_in_use += memory.len() as u64;
545 }
546 Err(e) => {
547 tracing::warn!("Memory pool stats mutex poisoned in allocate: {}", e);
548 }
549 }
550 return Ok(memory);
551 }
552 }
553
554 match self.stats.lock() {
556 Ok(mut stats) => {
557 stats.pool_misses += 1;
558 stats.total_allocations += 1;
559 stats.active_allocations += 1;
560 stats.total_bytes_allocated += self.chunk_size as u64;
561 stats.current_bytes_in_use += self.chunk_size as u64;
562 }
563 Err(e) => {
564 tracing::warn!("Memory pool stats mutex poisoned in allocate: {}", e);
565 }
566 }
567
568 SecureMemory::new(self.chunk_size)
569 }
570
571 pub fn deallocate(&self, mut memory: SecureMemory) {
573 memory.zeroize();
575
576 let memory_size = memory.len();
577
578 if memory_size == self.chunk_size {
579 match self.available.lock() {
581 Ok(mut available) => {
582 available.push_back(memory);
583 }
584 Err(e) => {
585 tracing::warn!("Memory pool available mutex poisoned in deallocate: {}", e);
586 }
587 }
588 }
589 match self.stats.lock() {
592 Ok(mut stats) => {
593 stats.total_deallocations += 1;
594 stats.active_allocations -= 1;
595 stats.current_bytes_in_use -= memory_size as u64;
596 }
597 Err(e) => {
598 tracing::warn!("Memory pool stats mutex poisoned in deallocate: {}", e);
599 }
600 }
601 }
602
603 pub fn stats(&self) -> PoolStats {
605 match self.stats.lock() {
606 Ok(s) => s.clone(),
607 Err(e) => {
608 tracing::warn!("Memory pool stats mutex poisoned: {}", e);
609 PoolStats::default()
610 }
611 }
612 }
613
614 fn preallocate_chunks(&self) -> Result<()> {
616 let num_chunks = self.total_size / self.chunk_size;
617 match self.available.lock() {
618 Ok(mut available) => {
619 for _ in 0..num_chunks {
620 let memory = SecureMemory::new(self.chunk_size)?;
621 available.push_back(memory);
622 }
623 }
624 Err(e) => {
625 tracing::warn!(
626 "Memory pool available mutex poisoned in preallocate_chunks: {}",
627 e
628 );
629 }
630 }
631
632 Ok(())
633 }
634}
635
636static GLOBAL_POOL: std::sync::OnceLock<Result<SecureMemoryPool>> = std::sync::OnceLock::new();
638
639pub fn global_secure_pool() -> &'static SecureMemoryPool {
641 let result = GLOBAL_POOL.get_or_init(SecureMemoryPool::default_pool);
642 match result {
643 Ok(pool) => pool,
644 Err(_) => match SecureMemoryPool::new(DEFAULT_POOL_SIZE, 4096) {
645 Ok(pool) => {
646 let _ = GLOBAL_POOL.set(Ok(pool));
647 if let Some(Ok(pool)) = GLOBAL_POOL.get() {
648 pool
649 } else {
650 static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
652 once_cell::sync::OnceCell::new();
653 FALLBACK.get_or_init(|| SecureMemoryPool {
654 available: Mutex::new(VecDeque::new()),
655 total_size: DEFAULT_POOL_SIZE,
656 chunk_size: 4096,
657 stats: Mutex::new(PoolStats::default()),
658 })
659 }
660 }
661 Err(_) => {
662 static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
664 once_cell::sync::OnceCell::new();
665 FALLBACK.get_or_init(|| SecureMemoryPool {
666 available: Mutex::new(VecDeque::new()),
667 total_size: DEFAULT_POOL_SIZE,
668 chunk_size: 4096,
669 stats: Mutex::new(PoolStats::default()),
670 })
671 }
672 },
673 }
674}
675
676pub fn allocate_secure(size: usize) -> Result<SecureMemory> {
678 global_secure_pool().allocate(size)
679}
680
681pub fn secure_vec_with_capacity(capacity: usize) -> Result<SecureVec> {
683 let memory = global_secure_pool().allocate(capacity)?;
684 Ok(SecureVec { memory, len: 0 })
685}
686
687pub fn secure_string_with_capacity(capacity: usize) -> Result<SecureString> {
689 let vec = secure_vec_with_capacity(capacity)?;
690 Ok(SecureString { vec })
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_memory_basic() {
        let mut memory = SecureMemory::new(1024).unwrap();

        // 1024 is already a multiple of SECURE_ALIGNMENT, so no padding is added.
        assert_eq!(memory.len(), 1024);
        assert!(!memory.is_empty());

        memory.as_mut_slice()[0] = 42;
        assert_eq!(memory.as_slice()[0], 42);

        memory.zeroize();
        assert_eq!(memory.as_slice()[0], 0);
    }

    #[test]
    fn test_secure_memory_rejects_invalid_sizes() {
        assert!(SecureMemory::new(0).is_err());
        assert!(SecureMemory::new(MAX_SECURE_ALLOCATION + 1).is_err());
    }

    #[test]
    fn test_secure_memory_constant_time_comparison() {
        let memory1 = SecureMemory::from_slice(b"hello").unwrap();
        let memory2 = SecureMemory::from_slice(b"hello").unwrap();
        let memory3 = SecureMemory::from_slice(b"world").unwrap();

        assert!(memory1.constant_time_eq(&memory2));
        assert!(!memory1.constant_time_eq(&memory3));
    }

    #[test]
    fn test_constant_time_eq_comprehensive() {
        let key1 = SecureMemory::from_slice(b"supersecretkey123456").unwrap();
        let key2 = SecureMemory::from_slice(b"supersecretkey123456").unwrap();
        assert!(key1.constant_time_eq(&key2), "Equal keys should match");

        let key_short = SecureMemory::from_slice(b"short").unwrap();
        let key_long = SecureMemory::from_slice(b"muchlongerkey").unwrap();
        assert!(
            !key_short.constant_time_eq(&key_long),
            "Different length keys should not match"
        );

        let key_first_diff1 = SecureMemory::from_slice(b"Aello123456789012").unwrap();
        let key_first_diff2 = SecureMemory::from_slice(b"Bello123456789012").unwrap();
        assert!(
            !key_first_diff1.constant_time_eq(&key_first_diff2),
            "Keys differing in first byte should not match"
        );

        let key_last_diff1 = SecureMemory::from_slice(b"hello12345678901A").unwrap();
        let key_last_diff2 = SecureMemory::from_slice(b"hello12345678901B").unwrap();
        assert!(
            !key_last_diff1.constant_time_eq(&key_last_diff2),
            "Keys differing in last byte should not match"
        );

        let key_mid_diff1 = SecureMemory::from_slice(b"hello1X3456789012").unwrap();
        let key_mid_diff2 = SecureMemory::from_slice(b"hello1Y3456789012").unwrap();
        assert!(
            !key_mid_diff1.constant_time_eq(&key_mid_diff2),
            "Keys differing in middle byte should not match"
        );

        let one_byte = SecureMemory::from_slice(b"a").unwrap();
        let two_bytes = SecureMemory::from_slice(b"ab").unwrap();
        assert!(
            !one_byte.constant_time_eq(&two_bytes),
            "Keys with length difference should not match"
        );

        let single1 = SecureMemory::from_slice(b"A").unwrap();
        let single2 = SecureMemory::from_slice(b"A").unwrap();
        let single3 = SecureMemory::from_slice(b"B").unwrap();
        assert!(
            single1.constant_time_eq(&single2),
            "Equal single byte keys should match"
        );
        assert!(
            !single1.constant_time_eq(&single3),
            "Different single byte keys should not match"
        );

        let large1 = SecureMemory::from_slice(b"12345678901234567890123456789012").unwrap();
        let large2 = SecureMemory::from_slice(b"12345678901234567890123456789012").unwrap();
        let large3 = SecureMemory::from_slice(b"12345678901234567890123456789013").unwrap();
        assert!(
            large1.constant_time_eq(&large2),
            "Equal large keys should match"
        );
        assert!(
            !large1.constant_time_eq(&large3),
            "Different large keys should not match"
        );
    }

    #[test]
    fn test_secure_vec() {
        let mut vec = SecureVec::with_capacity(100).unwrap();

        vec.push(1).unwrap();
        vec.push(2).unwrap();
        vec.extend_from_slice(&[3, 4, 5]).unwrap();

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.as_slice(), &[1, 2, 3, 4, 5]);

        vec.clear();
        assert_eq!(vec.len(), 0);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_secure_vec_capacity_exceeded() {
        let mut vec = SecureVec::with_capacity(SECURE_ALIGNMENT).unwrap();
        assert_eq!(vec.capacity(), SECURE_ALIGNMENT);

        vec.extend_from_slice(&[7u8; SECURE_ALIGNMENT]).unwrap();
        assert!(vec.push(8).is_err());
        assert!(vec.extend_from_slice(&[9]).is_err());
        assert_eq!(vec.len(), SECURE_ALIGNMENT);
    }

    #[test]
    fn test_secure_string() {
        let mut string = SecureString::with_capacity(100).unwrap();

        string.push('H').unwrap();
        string.push_str("ello").unwrap();

        assert_eq!(string.as_str().unwrap(), "Hello");
        assert_eq!(string.len(), 5);

        string.clear();
        assert_eq!(string.len(), 0);
        assert!(string.is_empty());
    }

    #[test]
    fn test_secure_memory_pool() {
        let pool = SecureMemoryPool::new(8192, 1024).unwrap();

        let memory1 = pool.allocate(512).unwrap();
        let memory2 = pool.allocate(1024).unwrap();

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.active_allocations, 2);

        pool.deallocate(memory1);
        pool.deallocate(memory2);

        let stats = pool.stats();
        assert_eq!(stats.total_deallocations, 2);
        assert_eq!(stats.active_allocations, 0);
    }

    #[test]
    fn test_secure_memory_pool_rejects_oversized_chunk() {
        assert!(SecureMemoryPool::new(1024, 2048).is_err());
    }

    #[test]
    fn test_global_pool() {
        // Requests no larger than the global pool's 4096-byte chunk size are
        // served with a whole chunk.
        let memory = allocate_secure(256).unwrap();
        assert_eq!(memory.len(), 4096);

        let vec = secure_vec_with_capacity(128).unwrap();
        assert_eq!(vec.capacity(), 4096);

        let string = secure_string_with_capacity(64).unwrap();
        assert_eq!(string.vec.capacity(), 4096);
    }
}