Skip to main content

zk_alloc/
lib.rs

1//! Bump-pointer arena allocator for ZK proving workloads.
2//!
3//! # Two-allocator model
4//!
5//! `ZkAllocator` is a façade over two allocators selected per call:
6//!
7//! - **Arena**: one `mmap` region split into per-thread slabs. Allocation
8//!   bumps a thread-local pointer; `dealloc` is a no-op. `begin_phase()`
9//!   resets every slab so the next phase reuses the same physical pages.
10//! - **System**: `std::alloc::System` (glibc on Linux). Used for everything
11//!   the arena shouldn't hold:
12//!   - any allocation when no phase is active;
13//!   - any allocation smaller than [`min_arena_bytes()`] even during a phase
14//!     (size-routing — keeps small library bookkeeping outside the arena);
15//!   - oversize allocations or threads that arrived after slabs were claimed
16//!     ([`overflow_stats()`] reports these);
17//!   - regrowth via `realloc` of a pointer that was already in System
18//!     (sticky-System routing — System allocations don't migrate to arena
19//!     on growth, even if the new size exceeds the size-routing threshold).
20//!
21//! # Phase scoping contract
22//!
23//! `begin_phase()` activates the arena and resets every slab. `end_phase()`
24//! deactivates the arena. Allocations made during phase N must not be held
25//! past `begin_phase()` of phase N+1: that call recycles the slab, and the
26//! next allocation at the same offset will silently overwrite the retained
27//! bytes.
28//!
29//! Practical rules:
30//!
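//! 1. Drop arena-allocated values before the next `begin_phase()`, or
//!    `clone()` them *after* the phase ends so the copy is served by System
//!    (a clone taken while the phase is still active is itself arena-backed).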
32//! 2. Use [`PhaseGuard`] / [`phase`] to ensure `end_phase` runs even on
33//!    panic — without it, an unwinding phase leaves the arena active and
34//!    subsequent "post-phase" allocations land in arena territory.
35//! 3. Keep long-lived state (thread pools, channels, registries, caches)
36//!    constructed *outside* any active phase so it lives in System.
37//!
38//! # Realloc migration: prevented
39//!
40//! `realloc` checks whether the input pointer lies in the arena region.
41//! If it does, growth goes through the normal arena path (subject to
42//! size-routing). If it does not, growth stays in System via
43//! `System::realloc` — preventing the failure mode where a System-backed
44//! `Vec` silently migrates into the arena on `push`.
45//!
46//! # Configuration
47//!
48//! - `ZK_ALLOC_SLAB_GB` — per-thread slab size in GiB (default `8`).
49//! - `ZK_ALLOC_MIN_BYTES` — size-routing threshold in bytes (default `4096`).
50//!   Set to `0` to send every active-phase allocation to the arena.
51//!
52//! # Example
53//!
54//! ```ignore
55//! use zk_alloc::ZkAllocator;
56//!
57//! #[global_allocator]
58//! static ALLOC: ZkAllocator = ZkAllocator;
59//!
60//! loop {
61//!     let proof = zk_alloc::phase(|| heavy_work()); // arena on inside
62//!     let output = proof.clone();                   // detach into System
63//!     submit(output);
64//! }
65//! ```
66
67use std::alloc::{GlobalAlloc, Layout};
68use std::cell::Cell;
69use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
70use std::sync::Once;
71
72mod syscall;
73
/// Per-thread slab size, in GiB, used when `ZK_ALLOC_SLAB_GB` is unset or
/// cannot be parsed.
const DEFAULT_SLAB_GB: usize = 8;

/// Extra slabs mapped on top of one slab per available CPU, so a few more
/// threads than `available_parallelism()` can still claim arena space.
const SLACK: usize = 4;
76
/// Zero-sized global-allocator front end. Install it with
/// `#[global_allocator]`. All routing state (phase flag, generation, region)
/// lives in module-level statics and thread-locals, so every value of this
/// type behaves identically.
#[derive(Debug)]
pub struct ZkAllocator;
79
// ---- Region geometry (written once by `ensure_region`) ----

/// Guards the one-time `mmap` so concurrent first callers of
/// `ensure_region()` do not race.
static REGION_INIT: Once = Once::new();

/// Start address of the mapped region, or `0` while nothing is mapped yet.
/// `dealloc`/`realloc` compare pointers against it to recognise arena memory.
static REGION_BASE: AtomicUsize = AtomicUsize::new(0);

/// Length of the mapped region in bytes, stored alongside `REGION_BASE`.
static REGION_SIZE: AtomicUsize = AtomicUsize::new(0);

/// Bytes per thread slab, derived from `ZK_ALLOC_SLAB_GB` (default 8 GiB)
/// inside `ensure_region()`. Stays `0` until the region is mapped.
static SLAB_SIZE: AtomicUsize = AtomicUsize::new(0);

/// Number of slabs in the region: `available_parallelism()` plus `SLACK`,
/// fixed at init time.
static MAX_THREADS: AtomicUsize = AtomicUsize::new(0);

/// Next slab index to hand out. A thread takes one with `fetch_add` on its
/// first arena allocation. An index at or beyond `MAX_THREADS` means every
/// slab is taken; that thread marks itself `ARENA_NO_SLAB` and stays on the
/// system allocator for good.
static THREAD_IDX: AtomicUsize = AtomicUsize::new(0);

// ---- Phase state ----

/// Master switch. `begin_phase` sets it to `true`, which sends eligible
/// allocations to the arena. `end_phase` sets it back to `false`, which
/// sends everything to the system allocator.
static ARENA_ACTIVE: AtomicBool = AtomicBool::new(false);

/// Phase counter bumped by `begin_phase()`. Each thread remembers the value it
/// last synced to in its `ARENA_GEN` thread-local. On a mismatch, that
/// thread's next arena allocation rewinds its cursor to the slab start. One
/// atomic increment therefore resets every slab lazily, without locking or
/// signalling the other threads.
static GENERATION: AtomicUsize = AtomicUsize::new(0);

// ---- Size routing ----

/// Built-in size-routing threshold: one page (4096 bytes).
///
/// During an active phase, allocations smaller than the threshold are served
/// by System. Small library bookkeeping that commonly outlives a phase
/// therefore never lands in recycled arena memory. Examples: crossbeam_deque
/// `Injector` blocks (~1.5 KB), tracing-subscriber `Registry` slot data,
/// hashbrown `HashMap` entries, rayon-core job stack frames (all sub-KB).
/// `ZK_ALLOC_MIN_BYTES` overrides the threshold; `0` disables the routing.
const DEFAULT_MIN_ARENA_BYTES: usize = 4096;
/// Threshold currently in force. It starts at the default and may be
/// replaced by `ZK_ALLOC_MIN_BYTES` when the region is first mapped.
static MIN_ARENA_BYTES: AtomicUsize = AtomicUsize::new(DEFAULT_MIN_ARENA_BYTES);

// ---- Overflow accounting ----

/// Number of arena-eligible allocations that System ended up serving.
static OVERFLOW_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Total requested bytes across those allocations.
static OVERFLOW_BYTES: AtomicUsize = AtomicUsize::new(0);
127
thread_local! {
    /// Start address of the slab this thread claimed; `0` until it claims one.
    static ARENA_BASE: Cell<usize> = const { Cell::new(0) };
    /// Exclusive upper bound (one past the last byte) of this thread's slab.
    static ARENA_END: Cell<usize> = const { Cell::new(0) };
    /// Bump cursor: the address this thread's next arena allocation starts
    /// from before alignment. Moved forward past every allocation.
    static ARENA_PTR: Cell<usize> = const { Cell::new(0) };
    /// The `GENERATION` value this thread last synced its cursor to.
    static ARENA_GEN: Cell<usize> = const { Cell::new(0) };
    /// Set once this thread found every slab already claimed.
    static ARENA_NO_SLAB: Cell<bool> = const { Cell::new(false) };
}
140
141fn ensure_region() -> usize {
142    REGION_INIT.call_once(|| {
143        let slab_gb = std::env::var("ZK_ALLOC_SLAB_GB")
144            .ok()
145            .and_then(|s| s.parse::<usize>().ok())
146            .unwrap_or(DEFAULT_SLAB_GB);
147        let slab_size = slab_gb << 30;
148        SLAB_SIZE.store(slab_size, Ordering::Release);
149
150        if let Ok(s) = std::env::var("ZK_ALLOC_MIN_BYTES") {
151            if let Ok(n) = s.parse::<usize>() {
152                MIN_ARENA_BYTES.store(n, Ordering::Release);
153            }
154        }
155
156        let cpus = std::thread::available_parallelism()
157            .map(|n| n.get())
158            .unwrap_or(8);
159        let max_threads = cpus + SLACK;
160        let region_size = slab_size * max_threads;
161
162        // SAFETY: mmap_anonymous returns a page-aligned pointer or null.
163        // MAP_NORESERVE means no physical memory is committed until pages are touched.
164        let ptr = unsafe { syscall::mmap_anonymous(region_size) };
165        if ptr.is_null() {
166            std::process::abort();
167        }
168        unsafe { syscall::madvise(ptr, region_size, syscall::MADV_NOHUGEPAGE) };
169        MAX_THREADS.store(max_threads, Ordering::Release);
170        REGION_SIZE.store(region_size, Ordering::Release);
171        REGION_BASE.store(ptr as usize, Ordering::Release);
172    });
173    REGION_BASE.load(Ordering::Acquire)
174}
175
176/// Activates the arena and resets every thread's slab. All allocations until the next
177/// `end_phase()` go to the arena; the previous phase's data is overwritten in place.
178///
179/// ## Phases must not nest
180///
181/// Calling `begin_phase()` while another phase is already active panics. The
182/// arena is a flat lifetime — nested phases were previously tolerated via a
183/// depth counter, but the depth counter masked correctness bugs (panics
184/// orphaning the count, accidental double-begin recycling the outer phase's
185/// slab on the next allocation). The contract is now: every `begin_phase()`
186/// is paired with one `end_phase()` (or use [`PhaseGuard`] / [`phase`] for
187/// panic-safe pairing), and no second `begin_phase()` is reachable from
188/// within an active phase.
189///
190/// ## Retention is unsafe
191///
192/// Allocations made during phase N that are still held when phase N+1 begins
193/// are silently overwritten by phase N+1's first allocations at the same slab
194/// offset. Any of the following held across `begin_phase()` will be corrupted:
195///
196/// - `Vec<T>` with capacity ≥ [`min_arena_bytes()`] (`push` triggers `realloc`
197///   that copies from now-recycled source memory).
198/// - `Arc<T>` / `Rc<T>` with payload ≥ [`min_arena_bytes()`] (refcount fields
199///   become arbitrary bytes — silent leak or use-after-free).
200/// - `HashMap`, `BTreeMap`, etc. with bucket allocation ≥ [`min_arena_bytes()`]
201///   (lookup may infinite-loop on corrupted ctrl bytes).
202/// - `Box<dyn Trait>` with backing data ≥ [`min_arena_bytes()`] (vtable
203///   dispatch survives but field reads return filler bytes).
204///
205/// To preserve data across phases, `clone()` it into a System-backed copy
206/// (e.g., wrap in `Box::leak(Box::new(...))` while ARENA_ACTIVE is false,
207/// or copy into a `Vec` allocated outside any phase).
208pub fn begin_phase() {
209    ensure_region();
210    let prev_active = ARENA_ACTIVE.swap(true, Ordering::Release);
211    assert!(
212        !prev_active,
213        "begin_phase() called while another phase is already active — phases must not nest"
214    );
215    GENERATION.fetch_add(1, Ordering::Release);
216}
217
218/// Deactivates the arena. New allocations go to the system allocator; existing arena
219/// pointers stay valid until the next `begin_phase()` resets the slabs.
220///
221/// With the `rayon-flush` feature (default), this also drains rayon's internal
222/// queues to release any crossbeam-deque blocks allocated during the phase.
223///
224/// Idempotent: calling `end_phase()` while no phase is active is a no-op.
225pub fn end_phase() {
226    ARENA_ACTIVE.store(false, Ordering::Release);
227    #[cfg(feature = "rayon-flush")]
228    flush_rayon();
229}
230
/// Runs `FLUSH_JOBS` empty `rayon::join` calls so rayon's crossbeam-deque
/// injector moves past, and releases, any blocks allocated while the arena
/// was active. Otherwise the next `begin_phase()` could recycle memory the
/// injector still references.
///
/// Each join occupies one injector slot. Once every slot of a block has been
/// consumed, crossbeam frees that block. `end_phase()` clears `ARENA_ACTIVE`
/// before calling this, so the fresh tail block is a System allocation.
#[cfg(feature = "rayon-flush")]
fn flush_rayon() {
    const FLUSH_JOBS: usize = 256;
    (0..FLUSH_JOBS).for_each(|_| {
        rayon::join(|| (), || ());
    });
}
245
246/// RAII guard for an arena phase. Calls `begin_phase()` on construction and
247/// `end_phase()` on drop — including during panic unwinding. Use this in
248/// place of paired `begin_phase()`/`end_phase()` calls when the phase body
249/// can panic, to avoid leaving the arena active across the unwind.
250///
251/// ```ignore
252/// loop {
253///     let _guard = zk_alloc::PhaseGuard::new();
254///     heavy_work_that_might_panic();
255///     // _guard drops here on normal return AND on unwind
256/// }
257/// ```
258pub struct PhaseGuard {
259    _private: (),
260}
261
262impl PhaseGuard {
263    /// Begins a phase. The phase ends when the returned guard is dropped.
264    pub fn new() -> Self {
265        begin_phase();
266        Self { _private: () }
267    }
268}
269
270impl Default for PhaseGuard {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276impl Drop for PhaseGuard {
277    fn drop(&mut self) {
278        end_phase();
279    }
280}
281
282/// Runs `f` inside a phase. Equivalent to constructing a `PhaseGuard`,
283/// running `f`, and dropping the guard. Panics in `f` propagate, but the
284/// phase is guaranteed to end before unwinding leaves this function.
285pub fn phase<F, R>(f: F) -> R
286where
287    F: FnOnce() -> R,
288{
289    let _guard = PhaseGuard::new();
290    f()
291}
292
293/// Returns (overflow_count, overflow_bytes) — allocations that fell through to System
294/// because they exceeded the slab or arrived after all slabs were claimed.
295pub fn overflow_stats() -> (usize, usize) {
296    (
297        OVERFLOW_COUNT.load(Ordering::Relaxed),
298        OVERFLOW_BYTES.load(Ordering::Relaxed),
299    )
300}
301
302pub fn reset_overflow_stats() {
303    OVERFLOW_COUNT.store(0, Ordering::Relaxed);
304    OVERFLOW_BYTES.store(0, Ordering::Relaxed);
305}
306
307/// Returns the per-thread slab size in bytes. Zero before the first `begin_phase()`.
308pub fn slab_size() -> usize {
309    SLAB_SIZE.load(Ordering::Relaxed)
310}
311
312/// Returns the minimum allocation size routed through the arena. Allocations
313/// smaller than this go to System even during active phases.
314pub fn min_arena_bytes() -> usize {
315    MIN_ARENA_BYTES.load(Ordering::Relaxed)
316}
317
318#[cold]
319#[inline(never)]
320unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 {
321    let generation = GENERATION.load(Ordering::Relaxed);
322    if !ARENA_NO_SLAB.get() && ARENA_GEN.get() != generation {
323        let mut base = ARENA_BASE.get();
324        if base == 0 {
325            let region = ensure_region();
326            let max = MAX_THREADS.load(Ordering::Relaxed);
327            let idx = THREAD_IDX.fetch_add(1, Ordering::Relaxed);
328            if idx >= max {
329                ARENA_NO_SLAB.set(true);
330                return unsafe {
331                    std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align))
332                };
333            }
334            let slab_size = SLAB_SIZE.load(Ordering::Relaxed);
335            base = region + idx * slab_size;
336            ARENA_BASE.set(base);
337            ARENA_END.set(base + slab_size);
338        }
339        ARENA_PTR.set(base);
340        ARENA_GEN.set(generation);
341        let aligned = (base + align - 1) & !(align - 1);
342        let new_ptr = aligned + size;
343        if new_ptr <= ARENA_END.get() {
344            ARENA_PTR.set(new_ptr);
345            return aligned as *mut u8;
346        }
347    }
348    OVERFLOW_COUNT.fetch_add(1, Ordering::Relaxed);
349    OVERFLOW_BYTES.fetch_add(size, Ordering::Relaxed);
350    unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) }
351}
352
353// SAFETY: All pointers returned are either from our mmap'd region (valid, aligned,
354// non-overlapping per thread) or from System. The arena is thread-local so no data
355// races. Relaxed ordering on ARENA_ACTIVE/GENERATION is sound: worst case a thread
356// sees a stale value and does one extra system-alloc before picking up the new
357// generation on the next call.
358unsafe impl GlobalAlloc for ZkAllocator {
359    #[inline(always)]
360    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
361        if ARENA_ACTIVE.load(Ordering::Relaxed) {
362            // Small allocs bypass arena: registry slots / HashMap entries /
363            // injector-block-sized allocations from rayon/tracing libraries
364            // commonly outlive a phase. Routing them to System keeps them
365            // safe across begin_phase()/end_phase() boundaries.
366            let min_bytes = MIN_ARENA_BYTES.load(Ordering::Relaxed);
367            if min_bytes != 0 && layout.size() < min_bytes {
368                return unsafe { std::alloc::System.alloc(layout) };
369            }
370            let generation = GENERATION.load(Ordering::Relaxed);
371            if ARENA_GEN.get() == generation {
372                let ptr = ARENA_PTR.get();
373                let aligned = (ptr + layout.align() - 1) & !(layout.align() - 1);
374                let new_ptr = aligned + layout.size();
375                if new_ptr <= ARENA_END.get() {
376                    ARENA_PTR.set(new_ptr);
377                    return aligned as *mut u8;
378                }
379            }
380            return unsafe { arena_alloc_cold(layout.size(), layout.align()) };
381        }
382        unsafe { std::alloc::System.alloc(layout) }
383    }
384
385    #[inline(always)]
386    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
387        let addr = ptr as usize;
388        let base = REGION_BASE.load(Ordering::Relaxed);
389        let region_size = REGION_SIZE.load(Ordering::Relaxed);
390        if base != 0 && addr >= base && addr < base + region_size {
391            return;
392        }
393        unsafe { std::alloc::System.dealloc(ptr, layout) };
394    }
395
396    #[inline(always)]
397    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
398        if new_size <= layout.size() {
399            return ptr;
400        }
401        // Sticky-System routing: if the original allocation came from System
402        // (small, or pre-phase, or routed by size-routing), keep the grown
403        // allocation in System too. Without this, a Vec allocated outside
404        // a phase that grows inside one would silently migrate into the
405        // arena and become subject to phase recycling.
406        let addr = ptr as usize;
407        let base = REGION_BASE.load(Ordering::Relaxed);
408        let region_size = REGION_SIZE.load(Ordering::Relaxed);
409        let in_arena = base != 0 && addr >= base && addr < base + region_size;
410        if !in_arena {
411            return unsafe { std::alloc::System.realloc(ptr, layout, new_size) };
412        }
413        let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
414        let new_ptr = unsafe { self.alloc(new_layout) };
415        if !new_ptr.is_null() {
416            // Use `ptr::copy` (memmove) instead of `copy_nonoverlapping`:
417            // when reallocating an arena pointer across a phase boundary,
418            // the cold-path slab reset (or fast-path bump after reset) can
419            // hand back a pointer that aliases or partially overlaps the
420            // source. `copy_nonoverlapping` is UB on overlap; `copy`
421            // handles it correctly. Modern x86_64 memcpy implementations
422            // happen to be safe for short overlaps in practice, but the
423            // language-level UB is real and would surface under miri or
424            // future codegen.
425            unsafe { std::ptr::copy(ptr, new_ptr, layout.size()) };
426            unsafe { self.dealloc(ptr, layout) };
427        }
428        new_ptr
429    }
430}