
zk_nalloc/bump.rs

//! Core bump allocator for nalloc.
//!
//! A bump allocator is about as fast as an allocator can get: it simply
//! increments a pointer. This module provides a thread-safe, atomic bump
//! allocator optimized for ZK prover workloads, with fallback support.
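//!
//! A minimal usage sketch (hedged: the buffer size and allocation parameters
//! are illustrative, and the example is not compiled as a doctest):
//!
//! ```ignore
//! let mut buf = vec![0u8; 4096];
//! // Safety: `buf` is valid and writable for its entire length.
//! let arena = unsafe { BumpAlloc::new(buf.as_mut_ptr(), buf.len()) };
//! let p = arena.alloc(64, 8);
//! assert!(!p.is_null());
//! ```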

#[cfg(feature = "fallback")]
use std::alloc::{GlobalAlloc, Layout, System};
use std::ptr::NonNull;
use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};

use crate::config::SECURE_WIPE_PATTERN;

/// A fast, lock-free bump allocator with fallback support.
///
/// Thread-safety is achieved via atomic compare-and-swap on the cursor.
/// This allows multiple threads to allocate concurrently without locks,
/// though there may be occasional retries on contention.
///
/// When the arena is exhausted and the `fallback` feature is enabled,
/// allocations fall back to the system allocator.
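///
/// A hedged sketch of concurrent use (buffer and allocation sizes are
/// illustrative; it relies only on `BumpAlloc` being `Sync`):
///
/// ```ignore
/// let mut buf = vec![0u8; 1 << 20];
/// let arena = unsafe { BumpAlloc::new(buf.as_mut_ptr(), buf.len()) };
/// std::thread::scope(|s| {
///     for _ in 0..4 {
///         s.spawn(|| {
///             // Each thread bumps the shared cursor via CAS; no locks are taken.
///             let p = arena.alloc(128, 16);
///             assert!(!p.is_null());
///         });
///     }
/// });
/// ```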
pub struct BumpAlloc {
    /// Base pointer of the memory region (never changes after init).
    base: NonNull<u8>,
    /// End pointer of the memory region (never changes after init).
    limit: NonNull<u8>,
    /// Current allocation cursor (atomically updated).
    cursor: AtomicUsize,
    /// Tracks whether the arena has been recycled (reset after use).
    /// Used to optimize zero-initialization in WitnessArena.
    is_recycled: AtomicBool,
    /// Counter for fallback allocations (for monitoring).
    #[cfg(feature = "fallback")]
    fallback_count: AtomicUsize,
    /// Total bytes allocated via fallback.
    #[cfg(feature = "fallback")]
    fallback_bytes: AtomicUsize,
}

impl BumpAlloc {
    /// Create a new bump allocator from a raw memory block.
    ///
    /// # Safety
    /// The memory block `[base, base+size)` must be valid and writable.
    #[inline]
    pub unsafe fn new(base: *mut u8, size: usize) -> Self {
        debug_assert!(!base.is_null());
        debug_assert!(size > 0);

        let base_nn = NonNull::new_unchecked(base);
        let limit_nn = NonNull::new_unchecked(base.add(size));

        Self {
            base: base_nn,
            limit: limit_nn,
            cursor: AtomicUsize::new(base as usize),
            is_recycled: AtomicBool::new(false),
            #[cfg(feature = "fallback")]
            fallback_count: AtomicUsize::new(0),
            #[cfg(feature = "fallback")]
            fallback_bytes: AtomicUsize::new(0),
        }
    }

    /// Get the base pointer of this allocator.
    #[inline]
    pub fn base_ptr(&self) -> *mut u8 {
        self.base.as_ptr()
    }

    /// Allocate memory with the given size and alignment.
    ///
    /// Returns a null pointer if there is not enough space and fallback is disabled.
    /// With the `fallback` feature, falls back to the system allocator.
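    ///
    /// A hedged usage sketch (`arena` stands for any `BumpAlloc` over
    /// caller-owned memory; sizes are illustrative):
    ///
    /// ```ignore
    /// let p = arena.alloc(std::mem::size_of::<u64>(), std::mem::align_of::<u64>());
    /// if p.is_null() {
    ///     // Arena exhausted (with `fallback` disabled) or the request was invalid.
    /// }
    /// ```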
    #[inline(always)]
    pub fn alloc(&self, size: usize, align: usize) -> *mut u8 {
        // Runtime validation (Issue #6): prevent memory corruption from invalid inputs
        if size == 0 || align == 0 || !align.is_power_of_two() {
            return std::ptr::null_mut();
        }

        loop {
            let current = self.cursor.load(Ordering::Relaxed);

            // Issue #7: Use checked arithmetic to prevent integer overflow
            let aligned = match current.checked_add(align - 1) {
                Some(v) => v & !(align - 1),
                None => return self.handle_exhaustion(size, align),
            };
            let next = match aligned.checked_add(size) {
                Some(v) => v,
                None => return self.handle_exhaustion(size, align),
            };

            if next > self.limit.as_ptr() as usize {
                // Arena exhausted
                return self.handle_exhaustion(size, align);
            }

            if self
                .cursor
                .compare_exchange_weak(current, next, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return aligned as *mut u8;
            }
            // Contention: another thread allocated concurrently. Retry.
        }
    }

    /// Handle arena exhaustion: either fall back to the system allocator or return null.
    #[cold]
    #[inline(never)]
    fn handle_exhaustion(&self, size: usize, align: usize) -> *mut u8 {
        #[cfg(debug_assertions)]
        {
            eprintln!(
                "[nalloc] Arena exhausted: requested {} bytes (align {}), remaining {} bytes",
                size,
                align,
                self.remaining()
            );
        }

        #[cfg(feature = "fallback")]
        {
            // Fall back to system allocator
            let layout = match Layout::from_size_align(size, align) {
                Ok(l) => l,
                Err(_) => return std::ptr::null_mut(),
            };

            let ptr = unsafe { System.alloc(layout) };

            if !ptr.is_null() {
                self.fallback_count.fetch_add(1, Ordering::Relaxed);
                self.fallback_bytes.fetch_add(size, Ordering::Relaxed);

                #[cfg(debug_assertions)]
                eprintln!("[nalloc] Fallback allocation: {} bytes", size);
            }

            ptr
        }

        #[cfg(not(feature = "fallback"))]
        {
            std::ptr::null_mut()
        }
    }

    /// Check if this arena has been recycled (reset after initial use).
    ///
    /// Uses `Acquire` ordering so that all memory writes performed by the
    /// thread that called `reset()` (in particular the volatile zeroing in
    /// `secure_reset`) are visible to the caller before any subsequent reads
    /// from arena memory. A `Relaxed` load would break the happens-before
    /// chain with the `Release` store in `reset()`.
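    ///
    /// A hedged sketch of the guarantee (illustrative; the two halves are
    /// assumed to run on different threads):
    ///
    /// ```ignore
    /// // Thread A: wipes the arena and publishes the flag.
    /// unsafe { arena.secure_reset() };
    ///
    /// // Thread B: checks the flag before touching arena memory.
    /// if arena.is_recycled() {
    ///     // The `Acquire` load pairs with the `Release` store in `reset()`,
    ///     // so the wiped contents written by thread A are visible here.
    /// }
    /// ```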
    #[inline]
    pub fn is_recycled(&self) -> bool {
        self.is_recycled.load(Ordering::Acquire)
    }

    /// Get the number of fallback allocations (only with `fallback` feature).
    #[cfg(feature = "fallback")]
    #[inline]
    pub fn fallback_count(&self) -> usize {
        self.fallback_count.load(Ordering::Relaxed)
    }

    /// Get the total bytes allocated via fallback (only with `fallback` feature).
    ///
    /// **Note (Issue #9)**: This tracks the *requested* allocation size, not the actual
    /// size allocated by the system allocator (which may be larger due to alignment
    /// and internal bookkeeping). Use this for monitoring, not precise accounting.
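    ///
    /// A hedged monitoring sketch (log format is illustrative):
    ///
    /// ```ignore
    /// if arena.fallback_count() > 0 {
    ///     eprintln!(
    ///         "[nalloc] {} fallback allocations (~{} bytes requested)",
    ///         arena.fallback_count(),
    ///         arena.fallback_bytes(),
    ///     );
    /// }
    /// ```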
    #[cfg(feature = "fallback")]
    #[inline]
    pub fn fallback_bytes(&self) -> usize {
        self.fallback_bytes.load(Ordering::Relaxed)
    }

    /// Reset the bump pointer to the base.
    ///
    /// # Safety
    /// All previously allocated memory becomes invalid after this call.
    ///
    /// # Warning (Issue #10)
    /// **Fallback allocations are NOT freed by reset.** When arena exhaustion triggers
    /// fallback to the system allocator (with `fallback` feature), those allocations
    /// must be individually deallocated via `GlobalAlloc::dealloc`. If using NAlloc
    /// as the global allocator, this happens automatically when the memory is dropped.
    /// However, if using arenas directly, be aware that reset only reclaims arena memory,
    /// not system allocator memory.
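    ///
    /// A hedged per-round lifecycle sketch (arena-only allocations; any
    /// fallback pointers would still need to be freed via the system allocator):
    ///
    /// ```ignore
    /// let p = arena.alloc(1024, 32);
    /// // ... use `p` for this round ...
    /// unsafe { arena.reset() };
    /// // `p` is now dangling and must not be read or written again.
    /// ```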
    #[inline]
    pub unsafe fn reset(&self) {
        self.cursor
            .store(self.base.as_ptr() as usize, Ordering::SeqCst);
        self.is_recycled.store(true, Ordering::Release);

        #[cfg(feature = "fallback")]
        {
            // Reset fallback counters
            self.fallback_count.store(0, Ordering::Relaxed);
            self.fallback_bytes.store(0, Ordering::Relaxed);
        }
    }

    /// Zero out all memory in the arena and reset the cursor.
    ///
    /// This is critical for security-sensitive applications like ZK provers,
    /// where witness data must be wiped after use to prevent leakage.
    ///
    /// Uses volatile writes to prevent the compiler from optimizing away
    /// the zeroing operation (dead store elimination).
    ///
    /// # Safety
    /// All previously allocated memory becomes invalid after this call.
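    ///
    /// A hedged sketch of wiping witness data after a proof (identifiers and
    /// sizes are illustrative):
    ///
    /// ```ignore
    /// let witness = arena.alloc(4096, 32);
    /// // ... fill `witness` and run the prover ...
    /// unsafe { arena.secure_reset() };
    /// // Every arena byte now holds `SECURE_WIPE_PATTERN` and the cursor is back at base.
    /// assert!(arena.is_recycled());
    /// ```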
    #[inline]
    pub unsafe fn secure_reset(&self) {
        let base = self.base.as_ptr();
        let size = self.limit.as_ptr() as usize - base as usize;

        // Use volatile writes to prevent dead store elimination.
        // This ensures the memory is actually zeroed even if it's never read again.
        Self::volatile_memset(base, SECURE_WIPE_PATTERN, size);

        // Issue #5: Full memory barrier for multi-threaded safety.
        // compiler_fence only prevents compiler reordering, not CPU reordering.
        // Using atomic fence ensures other threads observe the zeroed memory
        // before seeing the reset cursor.
        fence(Ordering::SeqCst);

        self.reset();
    }

    /// Volatile memset implementation that cannot be optimized away.
    ///
    /// This is critical for cryptographic security - we need to guarantee
    /// that sensitive data is actually erased from memory.
    #[inline(never)]
    #[allow(unreachable_code)] // Platform-specific code paths return early, making fallback unreachable on some platforms
    unsafe fn volatile_memset(ptr: *mut u8, value: u8, len: usize) {
        // Method 1: Use platform-specific secure zeroing where available (for value == 0)
        #[cfg(any(target_os = "linux", target_os = "android"))]
        if value == 0 {
            // explicit_bzero is guaranteed not to be optimized away
            extern "C" {
                fn explicit_bzero(s: *mut libc::c_void, n: libc::size_t);
            }
            explicit_bzero(ptr as *mut libc::c_void, len);
            return;
        }

        #[cfg(target_vendor = "apple")]
        {
            // memset_s is guaranteed not to be optimized away (C11)
            // Note: memset_s supports non-zero values
            extern "C" {
                fn memset_s(
                    s: *mut libc::c_void,
                    smax: libc::size_t,
                    c: libc::c_int,
                    n: libc::size_t,
                ) -> libc::c_int;
            }
            let _ = memset_s(ptr as *mut libc::c_void, len, value as libc::c_int, len);
            return;
        }

        #[cfg(target_os = "windows")]
        if value == 0 {
            // RtlSecureZeroMemory is guaranteed not to be optimized away
            extern "system" {
                fn RtlSecureZeroMemory(ptr: *mut u8, len: usize);
            }
            RtlSecureZeroMemory(ptr, len);
            return;
        }

        // Issue #4: Generic volatile write loop for:
        // - Non-zero values on Linux/Android/Windows (platform APIs only handle zero)
        // - All values on other platforms
        // Using usize-sized writes for better performance.
        //
        // `write_volatile` through a `*mut usize` requires word alignment, and the
        // caller's base pointer is not guaranteed to be word-aligned, so fall back
        // to byte-wise volatile writes in that case.
        let word = std::mem::size_of::<usize>();
        if (ptr as usize) % word != 0 {
            for i in 0..len {
                std::ptr::write_volatile(ptr.add(i), value);
            }
            return;
        }

        let ptr_usize = ptr as *mut usize;
        let pattern_usize = if value == 0 {
            0usize
        } else {
            // Broadcast the byte value into every byte of a usize.
            let mut p = 0usize;
            for i in 0..word {
                p |= (value as usize) << (i * 8);
            }
            p
        };

        let full_words = len / word;
        let remainder = len % word;

        // Write full usize words
        for i in 0..full_words {
            std::ptr::write_volatile(ptr_usize.add(i), pattern_usize);
        }

        // Write remaining bytes
        let remainder_ptr = ptr.add(full_words * word);
        for i in 0..remainder {
            std::ptr::write_volatile(remainder_ptr.add(i), value);
        }
    }

    /// Returns the total capacity in bytes.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.limit.as_ptr() as usize - self.base.as_ptr() as usize
    }

    /// Returns the number of bytes currently allocated.
    #[inline]
    pub fn used(&self) -> usize {
        self.cursor.load(Ordering::Relaxed) - self.base.as_ptr() as usize
    }

    /// Returns the number of bytes remaining.
    #[inline]
    pub fn remaining(&self) -> usize {
        self.capacity() - self.used()
    }
}

// Safety: BumpAlloc can be shared across threads because:
// - `base` and `limit` are never modified after construction
// - `cursor` uses atomic operations for thread-safe updates
// - `is_recycled` uses atomic operations
unsafe impl Send for BumpAlloc {}
unsafe impl Sync for BumpAlloc {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nonnull_safety() {
        let mut buffer = vec![0u8; 1024];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        assert_eq!(alloc.capacity(), 1024);
        assert_eq!(alloc.used(), 0);
        assert_eq!(alloc.remaining(), 1024);
        assert!(!alloc.is_recycled());
    }

    #[test]
    fn test_recycled_flag() {
        let mut buffer = vec![0u8; 1024];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        assert!(!alloc.is_recycled());

        let _ = alloc.alloc(64, 8);
        assert!(!alloc.is_recycled());

        unsafe { alloc.reset() };
        assert!(alloc.is_recycled());
    }

    #[test]
    fn test_secure_reset_zeroes_memory() {
        let mut buffer = vec![0xFFu8; 1024];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Allocate and write data
        let ptr = alloc.alloc(512, 8);
        assert!(!ptr.is_null());
        unsafe {
            std::ptr::write_bytes(ptr, 0xAB, 512);
        }

        // Secure reset
        unsafe { alloc.secure_reset() };

        // Verify memory is zeroed
        for i in 0..1024 {
            assert_eq!(buffer[i], 0, "Byte {} not zeroed", i);
        }
    }

    #[test]
    fn test_alignment() {
        let mut buffer = vec![0u8; 4096];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Test various alignments
        for align_pow in 0..8 {
            let align = 1usize << align_pow;
            let ptr = alloc.alloc(64, align);
            assert!(!ptr.is_null());
            assert_eq!((ptr as usize) % align, 0, "Alignment {} failed", align);
        }
    }

    #[test]
    #[cfg(feature = "fallback")]
    fn test_fallback_allocation() {
        // Create a tiny arena that will exhaust quickly
        let mut buffer = vec![0u8; 256];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Fill the arena
        let _ = alloc.alloc(256, 1);

        // This should trigger fallback
        let ptr = alloc.alloc(64, 8);
        assert!(!ptr.is_null(), "Fallback allocation should succeed");

        assert!(alloc.fallback_count() > 0, "Fallback count should increase");
        assert!(alloc.fallback_bytes() >= 64, "Fallback bytes should track");

        // Don't forget to free the fallback allocation
        unsafe {
            System.dealloc(ptr, Layout::from_size_align(64, 8).unwrap());
        }
    }

    #[test]
    #[cfg(not(feature = "fallback"))]
    fn test_exhaustion_returns_null() {
        let mut buffer = vec![0u8; 256];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Fill the arena
        let _ = alloc.alloc(256, 1);

        // This should return null without fallback
        let ptr = alloc.alloc(64, 8);
        assert!(
            ptr.is_null(),
            "Should return null when exhausted without fallback"
        );
    }
}