saorsa_core/
secure_memory.rs

// Copyright 2024 Saorsa Labs Limited
//
// This software is dual-licensed under:
// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
// - Commercial License
//
// For AGPL-3.0 license, see LICENSE-AGPL-3.0
// For commercial licensing, contact: saorsalabs@gmail.com
//
// Unless required by applicable law or agreed to in writing, software
// distributed under these licenses is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

//! # Secure Memory Management for Cryptographic Operations
//!
//! This module provides memory-protected storage for cryptographic keys and sensitive data.
//! All allocations are automatically zeroized on drop and, where the operating system permits,
//! locked in RAM so that they are not swapped to disk.
//!
//! ## Security Features
//! - Automatic zeroization on drop (prevents key recovery)
//! - Memory locking to prevent swapping to disk (best effort; failures are logged)
//! - Protected allocation regions
//! - Constant-time comparison operations
//! - Guard pages to detect buffer overflows (not yet implemented)
//!
//! ## Performance Features
//! - Pool-based allocation to reduce fragmentation
//! - Batch allocation for multiple keys (not yet implemented)
//! - Efficient reuse of protected memory regions
//! - Minimal overhead for secure operations

32#![allow(unsafe_code)] // Required for secure memory operations: mlock, memory zeroing, and protected allocation
33
34use crate::{P2PError, Result};
35use std::alloc::{Layout, alloc_zeroed, dealloc};
36use std::collections::VecDeque;
37use std::fmt;
38use std::ops::Deref;
39use std::ptr::{self, NonNull};
40use std::sync::Mutex;
41
42#[cfg(unix)]
43use libc::{mlock, munlock};
44
45#[cfg(windows)]
46use winapi::um::memoryapi::{VirtualLock, VirtualUnlock};
47
/// Upper bound, in bytes, on any single secure allocation (64 KiB).
const MAX_SECURE_ALLOCATION: usize = 64 * 1024;

/// Total size, in bytes, of the default secure memory pool (1 MiB).
const DEFAULT_POOL_SIZE: usize = 1 << 20;

/// Byte alignment (and size granularity) of every secure allocation.
const SECURE_ALIGNMENT: usize = 64;
56
/// Heap buffer for secret data that is wiped and unlocked before it is freed.
///
/// `SecureMemory::new` allocates it zeroed, aligned to `SECURE_ALIGNMENT`, and tries (best
/// effort) to lock it in RAM. Dropping it zeroizes, unlocks and deallocates the buffer.
pub struct SecureMemory {
    /// Start of the allocation; never null.
    ptr: NonNull<u8>,
    /// Number of bytes exposed through `as_slice` / `as_mut_slice`.
    size: usize,
    /// Whether the pages are currently locked (cannot be swapped) and must be unlocked on drop.
    locked: bool,
    /// Layout handed to the allocator; needed again to deallocate.
    layout: Layout,
}

// SAFETY: `SecureMemory` uniquely owns the allocation behind `ptr` and never hands out an
// owning alias to it, so moving the value to another thread moves sole ownership with it.
unsafe impl Send for SecureMemory {}
// SAFETY: every `&self` method only reads the buffer; all mutation (writes, zeroization,
// lock-state changes) requires `&mut self`, so shared references cannot race with a writer.
unsafe impl Sync for SecureMemory {}
73
74/// Secure vector with automatic zeroization
75pub struct SecureVec {
76    /// Underlying secure memory
77    memory: SecureMemory,
78    /// Current length of the vector
79    len: usize,
80}
81
82/// Secure string with automatic zeroization
83pub struct SecureString {
84    /// Underlying secure vector
85    vec: SecureVec,
86}
87
88/// Pool for managing secure memory allocations
89pub struct SecureMemoryPool {
90    /// Available memory chunks
91    available: Mutex<VecDeque<SecureMemory>>,
92    /// Total pool size
93    total_size: usize,
94    /// Chunk size for allocations
95    chunk_size: usize,
96    /// Statistics
97    stats: Mutex<PoolStats>,
98}
99
/// Snapshot of the counters kept by a [`SecureMemoryPool`].
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Number of successful `allocate` calls.
    pub total_allocations: u64,
    /// Number of `deallocate` calls.
    pub total_deallocations: u64,
    /// Allocations handed out and not yet returned through `deallocate`.
    pub active_allocations: u64,
    /// Requests served from an idle pooled chunk.
    pub pool_hits: u64,
    /// Requests that needed a fresh allocation (large request or empty pool).
    pub pool_misses: u64,
    /// Bytes obtained from fresh allocations (reused chunks are not counted).
    pub total_bytes_allocated: u64,
    /// Bytes currently handed out.
    pub current_bytes_in_use: u64,
}
118
/// Failure modes of secure memory operations.
#[derive(Clone, Debug)]
pub enum SecureMemoryError {
    /// The allocator could not provide the memory; carries a description.
    AllocationFailed(String),
    /// Locking pages in RAM failed; carries a description.
    LockingFailed(String),
    /// A size or alignment argument was rejected; carries a description.
    InvalidParameters(String),
    /// No pooled memory was available.
    PoolExhausted,
    /// The current platform cannot perform the operation; carries a description.
    NotSupported(String),
}

impl std::fmt::Display for SecureMemoryError {
    /// Renders `"<label>: <detail>"`, or just the label for detail-less variants.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::AllocationFailed(msg) => ("Allocation failed", Some(msg)),
            Self::LockingFailed(msg) => ("Memory locking failed", Some(msg)),
            Self::InvalidParameters(msg) => ("Invalid parameters", Some(msg)),
            Self::PoolExhausted => ("Secure memory pool exhausted", None),
            Self::NotSupported(msg) => ("Operation not supported", Some(msg)),
        };
        match detail {
            Some(msg) => write!(f, "{label}: {msg}"),
            None => f.write_str(label),
        }
    }
}

// Marker impl: the `Display` output above is the whole message; there is no source error.
impl std::error::Error for SecureMemoryError {}
147
148impl SecureMemory {
149    /// Allocate secure memory with the given size
150    pub fn new(size: usize) -> Result<Self> {
151        if size == 0 {
152            return Err(P2PError::Io(std::io::Error::new(
153                std::io::ErrorKind::InvalidInput,
154                "Cannot allocate zero-sized memory",
155            )));
156        }
157
158        if size > MAX_SECURE_ALLOCATION {
159            return Err(P2PError::Io(std::io::Error::new(
160                std::io::ErrorKind::InvalidInput,
161                format!("Allocation size {size} exceeds maximum {MAX_SECURE_ALLOCATION}"),
162            )));
163        }
164
165        // Align size to secure alignment boundary
166        let aligned_size = (size + SECURE_ALIGNMENT - 1) & !(SECURE_ALIGNMENT - 1);
167
168        // Create layout for allocation
169        let layout = Layout::from_size_align(aligned_size, SECURE_ALIGNMENT).map_err(|e| {
170            P2PError::Io(std::io::Error::new(
171                std::io::ErrorKind::InvalidInput,
172                format!("Invalid layout: {e}"),
173            ))
174        })?;
175
176        // Allocate zeroed memory
177        let ptr = unsafe { alloc_zeroed(layout) };
178        if ptr.is_null() {
179            return Err(P2PError::Io(std::io::Error::new(
180                std::io::ErrorKind::OutOfMemory,
181                "Memory allocation failed",
182            )));
183        }
184
185        let ptr = NonNull::new(ptr).ok_or_else(|| {
186            P2PError::Io(std::io::Error::new(
187                std::io::ErrorKind::OutOfMemory,
188                "Null pointer returned from allocator",
189            ))
190        })?;
191
192        let mut memory = Self {
193            ptr,
194            size: aligned_size,
195            locked: false,
196            layout,
197        };
198
199        // Attempt to lock the memory
200        if let Err(e) = memory.lock_memory() {
201            tracing::warn!("Failed to lock secure memory: {}", e);
202        }
203
204        Ok(memory)
205    }
206
207    /// Create secure memory from existing data (data is copied and source should be zeroized)
208    pub fn from_slice(data: &[u8]) -> Result<Self> {
209        let mut memory = Self::new(data.len())?;
210        memory.as_mut_slice().copy_from_slice(data);
211        Ok(memory)
212    }
213
214    /// Get the size of the allocated memory
215    pub fn len(&self) -> usize {
216        self.size
217    }
218
219    /// Check if the memory is empty
220    pub fn is_empty(&self) -> bool {
221        self.size == 0
222    }
223
224    /// Get a slice view of the memory
225    pub fn as_slice(&self) -> &[u8] {
226        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
227    }
228
229    /// Get a mutable slice view of the memory
230    pub fn as_mut_slice(&mut self) -> &mut [u8] {
231        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
232    }
233
234    /// Compare two secure memory regions in constant time
235    pub fn constant_time_eq(&self, other: &SecureMemory) -> bool {
236        if self.size != other.size {
237            return false;
238        }
239
240        let a = self.as_slice();
241        let b = other.as_slice();
242
243        // Constant-time comparison
244        let mut result = 0u8;
245        for i in 0..self.size {
246            result |= a[i] ^ b[i];
247        }
248
249        result == 0
250    }
251
252    /// Lock memory to prevent it from being swapped to disk
253    fn lock_memory(&mut self) -> Result<()> {
254        if self.locked {
255            return Ok(());
256        }
257
258        #[cfg(unix)]
259        {
260            let result = unsafe { mlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
261            if result != 0 {
262                return Err(P2PError::Io(std::io::Error::new(
263                    std::io::ErrorKind::PermissionDenied,
264                    "Failed to lock memory pages",
265                )));
266            }
267        }
268
269        #[cfg(windows)]
270        {
271            let result =
272                unsafe { VirtualLock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
273            if result == 0 {
274                return Err(P2PError::Io(std::io::Error::new(
275                    std::io::ErrorKind::PermissionDenied,
276                    "VirtualLock failed",
277                )));
278            }
279        }
280
281        #[cfg(not(any(unix, windows)))]
282        {
283            tracing::warn!("Memory locking not supported on this platform");
284        }
285
286        self.locked = true;
287        Ok(())
288    }
289
290    /// Unlock memory (called automatically on drop)
291    fn unlock_memory(&mut self) {
292        if !self.locked {
293            return;
294        }
295
296        #[cfg(unix)]
297        {
298            unsafe { munlock(self.ptr.as_ptr() as *const libc::c_void, self.size) };
299        }
300
301        #[cfg(windows)]
302        {
303            unsafe { VirtualUnlock(self.ptr.as_ptr() as *mut winapi::ctypes::c_void, self.size) };
304        }
305
306        self.locked = false;
307    }
308
309    /// Securely zeroize the memory
310    pub fn zeroize(&mut self) {
311        unsafe {
312            // Use volatile write to prevent compiler optimization
313            ptr::write_volatile(self.ptr.as_ptr(), 0u8);
314
315            // Zeroize the entire allocation
316            for i in 0..self.size {
317                ptr::write_volatile(self.ptr.as_ptr().add(i), 0u8);
318            }
319        }
320    }
321}
322
323impl Drop for SecureMemory {
324    fn drop(&mut self) {
325        // Zeroize memory before deallocation
326        self.zeroize();
327
328        // Unlock memory
329        self.unlock_memory();
330
331        // Deallocate memory
332        unsafe {
333            dealloc(self.ptr.as_ptr(), self.layout);
334        }
335    }
336}
337
338impl SecureVec {
339    /// Create a new secure vector with the given capacity
340    pub fn with_capacity(capacity: usize) -> Result<Self> {
341        let memory = SecureMemory::new(capacity)?;
342        Ok(Self { memory, len: 0 })
343    }
344
345    /// Create a secure vector from existing data
346    pub fn from_slice(data: &[u8]) -> Result<Self> {
347        let memory = SecureMemory::from_slice(data)?;
348        let len = data.len();
349        Ok(Self { memory, len })
350    }
351
352    /// Get the length of the vector
353    pub fn len(&self) -> usize {
354        self.len
355    }
356
357    /// Check if the vector is empty
358    pub fn is_empty(&self) -> bool {
359        self.len == 0
360    }
361
362    /// Get the capacity of the vector
363    pub fn capacity(&self) -> usize {
364        self.memory.len()
365    }
366
367    /// Push a byte to the vector
368    pub fn push(&mut self, value: u8) -> Result<()> {
369        if self.len >= self.capacity() {
370            return Err(P2PError::Io(std::io::Error::new(
371                std::io::ErrorKind::InvalidInput,
372                "SecureVec capacity exceeded",
373            )));
374        }
375
376        self.memory.as_mut_slice()[self.len] = value;
377        self.len += 1;
378        Ok(())
379    }
380
381    /// Extend the vector with data from a slice
382    pub fn extend_from_slice(&mut self, data: &[u8]) -> Result<()> {
383        if self.len + data.len() > self.capacity() {
384            return Err(P2PError::Io(std::io::Error::new(
385                std::io::ErrorKind::InvalidInput,
386                "SecureVec capacity exceeded",
387            )));
388        }
389
390        self.memory.as_mut_slice()[self.len..self.len + data.len()].copy_from_slice(data);
391        self.len += data.len();
392        Ok(())
393    }
394
395    /// Get a slice of the vector's contents
396    pub fn as_slice(&self) -> &[u8] {
397        &self.memory.as_slice()[..self.len]
398    }
399
400    /// Clear the vector (zeroizes the data)
401    pub fn clear(&mut self) {
402        self.memory.zeroize();
403        self.len = 0;
404    }
405}
406
407impl Deref for SecureVec {
408    type Target = [u8];
409
410    fn deref(&self) -> &Self::Target {
411        self.as_slice()
412    }
413}
414
415impl SecureString {
416    /// Create a new secure string with the given capacity
417    pub fn with_capacity(capacity: usize) -> Result<Self> {
418        let vec = SecureVec::with_capacity(capacity)?;
419        Ok(Self { vec })
420    }
421
422    /// Create a secure string from a regular string
423    pub fn from_plain_str(s: &str) -> Result<Self> {
424        let vec = SecureVec::from_slice(s.as_bytes())?;
425        Ok(Self { vec })
426    }
427
428    /// Get the length of the string
429    pub fn len(&self) -> usize {
430        self.vec.len()
431    }
432
433    /// Check if the string is empty
434    pub fn is_empty(&self) -> bool {
435        self.vec.is_empty()
436    }
437
438    /// Push a character to the string
439    pub fn push(&mut self, ch: char) -> Result<()> {
440        let mut buffer = [0u8; 4];
441        let encoded = ch.encode_utf8(&mut buffer);
442        self.vec.extend_from_slice(encoded.as_bytes())
443    }
444
445    /// Push a string slice to the string
446    pub fn push_str(&mut self, s: &str) -> Result<()> {
447        self.vec.extend_from_slice(s.as_bytes())
448    }
449
450    /// Get the string as a str slice
451    pub fn as_str(&self) -> Result<&str> {
452        std::str::from_utf8(self.vec.as_slice()).map_err(|e| {
453            P2PError::Io(std::io::Error::new(
454                std::io::ErrorKind::InvalidData,
455                format!("Invalid UTF-8: {e}"),
456            ))
457        })
458    }
459
460    /// Clear the string (zeroizes the data)
461    pub fn clear(&mut self) {
462        self.vec.clear();
463    }
464}
465
466impl fmt::Display for SecureString {
467    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468        match self.as_str() {
469            Ok(s) => write!(f, "{s}"),
470            Err(_) => write!(f, "<invalid UTF-8>"),
471        }
472    }
473}
474
475impl fmt::Debug for SecureString {
476    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477        write!(f, "SecureString[{}]", self.len())
478    }
479}
480
481impl SecureMemoryPool {
482    /// Create a new secure memory pool
483    pub fn new(total_size: usize, chunk_size: usize) -> Result<Self> {
484        if chunk_size > total_size {
485            return Err(P2PError::Io(std::io::Error::new(
486                std::io::ErrorKind::InvalidInput,
487                "Chunk size cannot exceed total size",
488            )));
489        }
490
491        let pool = Self {
492            available: Mutex::new(VecDeque::new()),
493            total_size,
494            chunk_size,
495            stats: Mutex::new(PoolStats::default()),
496        };
497
498        // Pre-allocate chunks
499        pool.preallocate_chunks()?;
500
501        Ok(pool)
502    }
503
504    /// Create a default secure memory pool
505    pub fn default_pool() -> Result<Self> {
506        Self::new(DEFAULT_POOL_SIZE, 4096)
507    }
508
509    /// Allocate memory from the pool
510    pub fn allocate(&self, size: usize) -> Result<SecureMemory> {
511        if size > self.chunk_size {
512            // Large allocation - allocate directly
513            if let Ok(mut stats) = self.stats.lock() {
514                stats.pool_misses += 1;
515                stats.total_allocations += 1;
516                stats.active_allocations += 1;
517                stats.total_bytes_allocated += size as u64;
518                stats.current_bytes_in_use += size as u64;
519            }
520            return SecureMemory::new(size);
521        }
522
523        // Try to get from pool
524        {
525            if let Ok(mut available) = self.available.lock()
526                && let Some(memory) = available.pop_front()
527            {
528                if let Ok(mut stats) = self.stats.lock() {
529                    stats.pool_hits += 1;
530                    stats.total_allocations += 1;
531                    stats.active_allocations += 1;
532                    stats.current_bytes_in_use += memory.len() as u64;
533                }
534                return Ok(memory);
535            }
536        }
537
538        // Pool empty - allocate new chunk
539        if let Ok(mut stats) = self.stats.lock() {
540            stats.pool_misses += 1;
541            stats.total_allocations += 1;
542            stats.active_allocations += 1;
543            stats.total_bytes_allocated += self.chunk_size as u64;
544            stats.current_bytes_in_use += self.chunk_size as u64;
545        }
546
547        SecureMemory::new(self.chunk_size)
548    }
549
550    /// Return memory to the pool
551    pub fn deallocate(&self, mut memory: SecureMemory) {
552        // Zeroize before returning to pool
553        memory.zeroize();
554
555        let memory_size = memory.len();
556
557        if memory_size == self.chunk_size {
558            // Return to pool
559            if let Ok(mut available) = self.available.lock() {
560                available.push_back(memory);
561            }
562        }
563        // Large allocations are dropped automatically
564
565        if let Ok(mut stats) = self.stats.lock() {
566            stats.total_deallocations += 1;
567            stats.active_allocations -= 1;
568            stats.current_bytes_in_use -= memory_size as u64;
569        }
570    }
571
572    /// Get pool statistics
573    pub fn stats(&self) -> PoolStats {
574        self.stats.lock().map(|s| s.clone()).unwrap_or_default()
575    }
576
577    /// Pre-allocate chunks for the pool
578    fn preallocate_chunks(&self) -> Result<()> {
579        let num_chunks = self.total_size / self.chunk_size;
580        if let Ok(mut available) = self.available.lock() {
581            for _ in 0..num_chunks {
582                let memory = SecureMemory::new(self.chunk_size)?;
583                available.push_back(memory);
584            }
585        }
586
587        Ok(())
588    }
589}
590
591/// Global secure memory pool instance
592static GLOBAL_POOL: std::sync::OnceLock<Result<SecureMemoryPool>> = std::sync::OnceLock::new();
593
594/// Get the global secure memory pool
595pub fn global_secure_pool() -> &'static SecureMemoryPool {
596    let result = GLOBAL_POOL.get_or_init(SecureMemoryPool::default_pool);
597    match result {
598        Ok(pool) => pool,
599        Err(_) => match SecureMemoryPool::new(DEFAULT_POOL_SIZE, 4096) {
600            Ok(pool) => {
601                let _ = GLOBAL_POOL.set(Ok(pool));
602                if let Some(Ok(pool)) = GLOBAL_POOL.get() {
603                    pool
604                } else {
605                    // fallback to a static default
606                    static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
607                        once_cell::sync::OnceCell::new();
608                    FALLBACK.get_or_init(|| SecureMemoryPool {
609                        available: Mutex::new(VecDeque::new()),
610                        total_size: DEFAULT_POOL_SIZE,
611                        chunk_size: 4096,
612                        stats: Mutex::new(PoolStats::default()),
613                    })
614                }
615            }
616            Err(_) => {
617                // Provide minimal fallback rather than panic
618                static FALLBACK: once_cell::sync::OnceCell<SecureMemoryPool> =
619                    once_cell::sync::OnceCell::new();
620                FALLBACK.get_or_init(|| SecureMemoryPool {
621                    available: Mutex::new(VecDeque::new()),
622                    total_size: DEFAULT_POOL_SIZE,
623                    chunk_size: 4096,
624                    stats: Mutex::new(PoolStats::default()),
625                })
626            }
627        },
628    }
629}
630
631/// Convenience function to allocate secure memory from global pool
632pub fn allocate_secure(size: usize) -> Result<SecureMemory> {
633    global_secure_pool().allocate(size)
634}
635
636/// Convenience function to create a secure vector from global pool
637pub fn secure_vec_with_capacity(capacity: usize) -> Result<SecureVec> {
638    let memory = global_secure_pool().allocate(capacity)?;
639    Ok(SecureVec { memory, len: 0 })
640}
641
642/// Convenience function to create a secure string from global pool
643pub fn secure_string_with_capacity(capacity: usize) -> Result<SecureString> {
644    let vec = secure_vec_with_capacity(capacity)?;
645    Ok(SecureString { vec })
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_memory_basic() {
        let mut memory = SecureMemory::new(1024).unwrap();

        // Test basic operations
        assert_eq!(memory.len(), 1024);
        assert!(!memory.is_empty());

        // Test writing and reading
        memory.as_mut_slice()[0] = 42;
        assert_eq!(memory.as_slice()[0], 42);

        // Test zeroization
        memory.zeroize();
        assert_eq!(memory.as_slice()[0], 0);
    }

    #[test]
    fn test_secure_memory_constant_time_comparison() {
        let memory1 = SecureMemory::from_slice(b"hello").unwrap();
        let memory2 = SecureMemory::from_slice(b"hello").unwrap();
        let memory3 = SecureMemory::from_slice(b"world").unwrap();

        assert!(memory1.constant_time_eq(&memory2));
        assert!(!memory1.constant_time_eq(&memory3));
    }

    #[test]
    fn test_secure_memory_from_slice_unaligned_length() {
        // 5 is not a multiple of SECURE_ALIGNMENT; the visible length must still be 5.
        let memory = SecureMemory::from_slice(b"hello").unwrap();
        assert_eq!(memory.len(), 5);
        assert_eq!(memory.as_slice(), b"hello");
    }

    #[test]
    fn test_secure_vec() {
        let mut vec = SecureVec::with_capacity(100).unwrap();

        // Test basic operations
        vec.push(1).unwrap();
        vec.push(2).unwrap();
        vec.extend_from_slice(&[3, 4, 5]).unwrap();

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.as_slice(), &[1, 2, 3, 4, 5]);

        // Test clear
        vec.clear();
        assert_eq!(vec.len(), 0);
        assert!(vec.is_empty());
    }

    #[test]
    fn test_secure_vec_from_slice() {
        let vec = SecureVec::from_slice(&[1, 2, 3]).unwrap();
        assert_eq!(vec.as_slice(), &[1, 2, 3]);
        assert!(SecureVec::from_slice(&[]).is_err());
    }

    #[test]
    fn test_secure_string() {
        let mut string = SecureString::with_capacity(100).unwrap();

        // Test basic operations
        string.push('H').unwrap();
        string.push_str("ello").unwrap();

        assert_eq!(string.as_str().unwrap(), "Hello");
        assert_eq!(string.len(), 5);

        // Test clear
        string.clear();
        assert_eq!(string.len(), 0);
        assert!(string.is_empty());
    }

    #[test]
    fn test_secure_string_from_plain_str() {
        let string = SecureString::from_plain_str("secret").unwrap();
        assert_eq!(string.as_str().unwrap(), "secret");
        assert_eq!(string.len(), 6);
        // Debug output is redacted.
        assert_eq!(format!("{string:?}"), "SecureString[6]");
    }

    #[test]
    fn test_secure_memory_pool() {
        let pool = SecureMemoryPool::new(8192, 1024).unwrap();

        // Test allocation
        let memory1 = pool.allocate(512).unwrap();
        let memory2 = pool.allocate(1024).unwrap();

        // Check stats
        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.active_allocations, 2);

        // Test deallocation
        pool.deallocate(memory1);
        pool.deallocate(memory2);

        let stats = pool.stats();
        assert_eq!(stats.total_deallocations, 2);
        assert_eq!(stats.active_allocations, 0);
    }

    #[test]
    fn test_global_pool() {
        // Requests up to the pool's chunk size are served with a whole chunk, so only a
        // lower bound on the length can be asserted.
        let memory = allocate_secure(256).unwrap();
        assert!(memory.len() >= 256);

        let vec = secure_vec_with_capacity(128).unwrap();
        assert!(vec.capacity() >= 128);
        assert!(vec.is_empty());

        let string = secure_string_with_capacity(64).unwrap();
        assert!(string.vec.capacity() >= 64);
        assert!(string.is_empty());
    }
}