// scirs2_core/memory_pool/async_pool.rs
//! Async-style allocation queue with priority scheduling and memory pressure callbacks.
//!
//! `AsyncPool` wraps an `ArenaAllocator` and adds:
//!
//! * A priority queue of pending allocation requests (simulated asynchronously —
//!   processing is explicit and synchronous, but the API mirrors typical GPU
//!   stream-submission patterns).
//! * Memory-pressure callbacks fired when the fragmentation ratio exceeds a
//!   registered threshold.
//! * Simple throughput and latency tracking for profiling.
11
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::time::Instant;

use super::arena::ArenaAllocator;
use super::types::{
    AllocError, AllocationId, AllocationStats, AsyncAllocRequest, PoolConfig, RequestHandle,
};
21
// ---------------------------------------------------------------------------
// Priority-queue wrapper
// ---------------------------------------------------------------------------

/// Internal entry stored in the `BinaryHeap`.
///
/// We want *higher* priority values to be dequeued first; `BinaryHeap` is a
/// max-heap, so we can rely on that directly.
#[derive(Debug)]
struct PendingRequest {
    /// Scheduling priority copied from the request; higher values are served first.
    priority: u8,
    /// Monotonically-increasing sequence number used to break ties (FIFO within
    /// the same priority level — lower sequence number wins, so we invert the
    /// comparison).
    sequence: u64,
    /// Caller-visible handle used later to look up this request's result.
    handle: RequestHandle,
    /// The original allocation request (its `size`, `alignment`, and `priority`
    /// fields are read during processing).
    request: AsyncAllocRequest,
}
40
41impl PartialEq for PendingRequest {
42    fn eq(&self, other: &Self) -> bool {
43        self.priority == other.priority && self.sequence == other.sequence
44    }
45}
46
47impl Eq for PendingRequest {}
48
49impl PartialOrd for PendingRequest {
50    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
51        Some(self.cmp(other))
52    }
53}
54
55impl Ord for PendingRequest {
56    fn cmp(&self, other: &Self) -> Ordering {
57        // Higher priority first; ties broken by lower sequence number first.
58        self.priority
59            .cmp(&other.priority)
60            .then_with(|| other.sequence.cmp(&self.sequence))
61    }
62}
63
// ---------------------------------------------------------------------------
// Result tracking
// ---------------------------------------------------------------------------

/// Result state for a submitted request.
#[derive(Debug)]
enum RequestState {
    /// Submitted via `enqueue` but not yet handled by `process_queue`.
    Pending,
    /// Allocation succeeded; carries the arena's allocation id.
    Ready(AllocationId),
    /// Allocation failed; carries the arena's error.
    Failed(AllocError),
}
75
// ---------------------------------------------------------------------------
// Pressure callback entry
// ---------------------------------------------------------------------------

/// A registered memory-pressure observer.
struct PressureCallback {
    /// Fragmentation score above which `callback` fires (strict `>` comparison).
    threshold: f64,
    /// Invoked with the current fragmentation score on every check that trips
    /// the threshold; may fire repeatedly across checks.
    callback: Box<dyn Fn(f64) + Send>,
}
84
// ---------------------------------------------------------------------------
// Latency tracking (ring buffer of recent timings)
// ---------------------------------------------------------------------------

/// Number of samples retained for the rolling latency average.
const LATENCY_WINDOW: usize = 256;

/// Fixed-size ring buffer of recent latency samples plus counters for
/// computing throughput over the current measurement window.
struct LatencyTracker {
    /// Ring buffer of the most recent latency samples, in nanoseconds.
    samples_ns: Vec<u64>,
    /// Next write position in the ring buffer.
    head: usize,
    /// Number of valid samples stored so far (saturates at `LATENCY_WINDOW`).
    count: usize,
    /// Total operations recorded over the tracker's lifetime.
    total_ops: u64,
    /// Start of the current throughput measurement window.
    window_start: Instant,
    /// Operations recorded since `window_start`.
    window_ops: u64,
}

impl LatencyTracker {
    /// Create an empty tracker with a zeroed sample buffer.
    fn new() -> Self {
        LatencyTracker {
            samples_ns: vec![0; LATENCY_WINDOW],
            head: 0,
            count: 0,
            total_ops: 0,
            window_start: Instant::now(),
            window_ops: 0,
        }
    }

    /// Record one latency sample, overwriting the oldest entry once the
    /// ring buffer is full.
    fn record(&mut self, latency_ns: u64) {
        self.samples_ns[self.head] = latency_ns;
        self.head = (self.head + 1) % LATENCY_WINDOW;
        self.count = (self.count + 1).min(LATENCY_WINDOW);
        self.total_ops += 1;
        self.window_ops += 1;
    }

    /// Mean of the currently stored samples, or `0.0` when none were recorded.
    fn avg_latency_ns(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => {
                // Valid samples always occupy the first `count` slots:
                // before wrap-around they are written in order, and after
                // wrap-around `count == LATENCY_WINDOW` covers the whole buffer.
                let total: u64 = self.samples_ns[..n].iter().sum();
                total as f64 / n as f64
            }
        }
    }

    /// Operations per second since the previous call; resets the measurement
    /// window as a side effect (unless the elapsed time is degenerate).
    fn throughput_ops_per_sec(&mut self) -> f64 {
        let secs = self.window_start.elapsed().as_secs_f64();
        if secs < 1e-9 {
            return 0.0;
        }
        let ops = std::mem::replace(&mut self.window_ops, 0) as f64;
        self.window_start = Instant::now();
        ops / secs
    }
}
142
// ---------------------------------------------------------------------------
// AsyncPool
// ---------------------------------------------------------------------------

/// A priority-queue–driven allocation pool with memory pressure callbacks.
pub struct AsyncPool {
    /// Underlying arena allocator.
    arena: ArenaAllocator,
    /// Priority queue of pending requests (max-heap: highest priority first,
    /// FIFO within a priority level).
    pending: BinaryHeap<PendingRequest>,
    /// Results for all submitted requests, keyed by `RequestHandle.0`.
    /// NOTE(review): entries are never removed, so this map grows with the
    /// total number of submissions — confirm that is acceptable for callers.
    results: HashMap<usize, RequestState>,
    /// Maximum capacity of the pending queue (from `PoolConfig::async_queue_size`).
    queue_capacity: usize,
    /// Monotonically increasing sequence counter for FIFO tie-breaking.
    sequence: u64,
    /// Next handle id.
    next_handle: usize,
    /// Memory-pressure callbacks.
    pressure_callbacks: Vec<PressureCallback>,
    /// Latency / throughput tracking.
    latency: LatencyTracker,
}
166
167impl AsyncPool {
168    /// Create a new `AsyncPool` from the given configuration.
169    pub fn new(config: PoolConfig) -> Self {
170        let queue_capacity = config.async_queue_size;
171        let arena = ArenaAllocator::new(config);
172        Self {
173            arena,
174            pending: BinaryHeap::new(),
175            results: HashMap::new(),
176            queue_capacity,
177            sequence: 0,
178            next_handle: 0,
179            pressure_callbacks: Vec::new(),
180            latency: LatencyTracker::new(),
181        }
182    }
183
184    // -----------------------------------------------------------------------
185    // Queue submission
186    // -----------------------------------------------------------------------
187
188    /// Submit an allocation request to the queue.
189    ///
190    /// Returns `Err(AllocError::PoolFull)` if the queue is at capacity.
191    pub fn enqueue(&mut self, req: AsyncAllocRequest) -> Result<RequestHandle, AllocError> {
192        if self.pending.len() >= self.queue_capacity {
193            return Err(AllocError::PoolFull);
194        }
195
196        let handle = RequestHandle(self.next_handle);
197        self.next_handle += 1;
198
199        let entry = PendingRequest {
200            priority: req.priority,
201            sequence: self.sequence,
202            handle,
203            request: req,
204        };
205        self.sequence += 1;
206        self.pending.push(entry);
207        self.results.insert(handle.0, RequestState::Pending);
208
209        Ok(handle)
210    }
211
212    // -----------------------------------------------------------------------
213    // Queue processing
214    // -----------------------------------------------------------------------
215
216    /// Process up to `max_allocations` pending requests, highest-priority first.
217    ///
218    /// Returns a `Vec` of `(RequestHandle, AllocationId)` pairs for every
219    /// request that completed successfully.  Failed requests are recorded in
220    /// the result map and can be queried via `get_result`.
221    pub fn process_queue(&mut self, max_allocations: usize) -> Vec<(RequestHandle, AllocationId)> {
222        let mut completed = Vec::new();
223
224        for _ in 0..max_allocations {
225            let entry = match self.pending.pop() {
226                Some(e) => e,
227                None => break,
228            };
229
230            let t0 = Instant::now();
231            let result = self
232                .arena
233                .alloc(entry.request.size, entry.request.alignment);
234            let latency_ns = t0.elapsed().as_nanos() as u64;
235            self.latency.record(latency_ns);
236
237            match result {
238                Ok(id) => {
239                    self.results.insert(entry.handle.0, RequestState::Ready(id));
240                    completed.push((entry.handle, id));
241                }
242                Err(e) => {
243                    self.results.insert(entry.handle.0, RequestState::Failed(e));
244                }
245            }
246        }
247
248        self.check_pressure();
249        completed
250    }
251
252    // -----------------------------------------------------------------------
253    // Result inspection
254    // -----------------------------------------------------------------------
255
256    /// Return `true` if the given request has been processed (successfully or not).
257    pub fn is_ready(&self, handle: RequestHandle) -> bool {
258        matches!(
259            self.results.get(&handle.0),
260            Some(RequestState::Ready(_)) | Some(RequestState::Failed(_))
261        )
262    }
263
264    /// Return the `AllocationId` for a successfully completed request, or `None`
265    /// if the request is still pending or failed.
266    pub fn get_result(&self, handle: RequestHandle) -> Option<AllocationId> {
267        match self.results.get(&handle.0) {
268            Some(RequestState::Ready(id)) => Some(*id),
269            _ => None,
270        }
271    }
272
273    /// Return the error for a failed request, or `None` if pending / successful.
274    pub fn get_error(&self, handle: RequestHandle) -> Option<&AllocError> {
275        match self.results.get(&handle.0) {
276            Some(RequestState::Failed(e)) => Some(e),
277            _ => None,
278        }
279    }
280
281    // -----------------------------------------------------------------------
282    // Memory pressure
283    // -----------------------------------------------------------------------
284
285    /// Register a callback to be fired when the pool fragmentation exceeds `threshold`.
286    ///
287    /// The callback receives the current fragmentation score.
288    pub fn register_pressure_callback(&mut self, threshold: f64, cb: Box<dyn Fn(f64) + Send>) {
289        self.pressure_callbacks.push(PressureCallback {
290            threshold,
291            callback: cb,
292        });
293    }
294
295    /// Evaluate all registered pressure thresholds and fire callbacks if exceeded.
296    pub fn check_pressure(&self) {
297        let stats = self.arena.stats();
298        let score = stats.fragmentation;
299
300        for entry in &self.pressure_callbacks {
301            if score > entry.threshold {
302                (entry.callback)(score);
303            }
304        }
305    }
306
307    // -----------------------------------------------------------------------
308    // Statistics & profiling
309    // -----------------------------------------------------------------------
310
311    /// Return the allocation throughput in operations per second since the last call.
312    pub fn throughput_ops_per_sec(&mut self) -> f64 {
313        self.latency.throughput_ops_per_sec()
314    }
315
316    /// Return the average allocation latency in nanoseconds over the last
317    /// `min(n_completed, 256)` operations.
318    pub fn avg_alloc_latency_ns(&self) -> f64 {
319        self.latency.avg_latency_ns()
320    }
321
322    /// Return a snapshot of the arena statistics.
323    pub fn stats(&self) -> AllocationStats {
324        self.arena.stats()
325    }
326
327    /// Return the number of pending (not-yet-processed) requests.
328    pub fn pending_count(&self) -> usize {
329        self.pending.len()
330    }
331
332    // -----------------------------------------------------------------------
333    // Pass-through arena access
334    // -----------------------------------------------------------------------
335
336    /// Free an allocation by id in the underlying arena.
337    pub fn free(&mut self, id: AllocationId) -> Result<(), AllocError> {
338        self.arena.free(id)
339    }
340
341    /// Access the underlying arena (read-only).
342    pub fn arena(&self) -> &ArenaAllocator {
343        &self.arena
344    }
345
346    /// Access the underlying arena (mutable).
347    pub fn arena_mut(&mut self) -> &mut ArenaAllocator {
348        &mut self.arena
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_pool::types::{AsyncAllocRequest, PoolConfig};
    use std::sync::{Arc, Mutex};

    /// Build a 1 MiB pool with a small (16-entry) submission queue.
    fn small_pool() -> AsyncPool {
        AsyncPool::new(PoolConfig {
            total_size: 1024 * 1024, // 1 MiB
            async_queue_size: 16,
            ..Default::default()
        })
    }

    /// Enqueueing alone must leave the request pending and counted.
    #[test]
    fn test_async_pool_enqueue() {
        let mut pool = small_pool();
        let req = AsyncAllocRequest::new(256, 5);
        let handle = pool.enqueue(req).expect("enqueue");
        assert!(
            !pool.is_ready(handle),
            "should be pending before processing"
        );
        assert_eq!(pool.pending_count(), 1);
    }

    /// With two queued requests, processing one must pick the higher priority
    /// regardless of submission order.
    #[test]
    fn test_async_pool_priority() {
        let mut pool = small_pool();
        // Enqueue low-priority first, then high-priority.
        let low_req = AsyncAllocRequest::new(64, 1);
        let high_req = AsyncAllocRequest::new(64, 10);
        let _low_handle = pool.enqueue(low_req).expect("enqueue low");
        let high_handle = pool.enqueue(high_req).expect("enqueue high");

        // Process only 1 request — it should be the high-priority one.
        let completed = pool.process_queue(1);
        assert_eq!(completed.len(), 1, "one request should complete");
        assert_eq!(
            completed[0].0, high_handle,
            "high-priority request should complete first"
        );
    }

    /// Every processed request should be marked ready and expose its result.
    #[test]
    fn test_async_pool_process() {
        let mut pool = small_pool();
        for _ in 0..5 {
            let req = AsyncAllocRequest::new(128, 5);
            pool.enqueue(req).expect("enqueue");
        }
        let completed = pool.process_queue(5);
        assert_eq!(completed.len(), 5);
        for (handle, _id) in &completed {
            assert!(pool.is_ready(*handle));
            assert!(pool.get_result(*handle).is_some());
        }
    }

    /// A threshold below any possible fragmentation score must always fire.
    #[test]
    fn test_pressure_callback() {
        let mut pool = small_pool();
        let fired = Arc::new(Mutex::new(false));
        let fired_clone = Arc::clone(&fired);

        // Register a callback at threshold -0.1, below any possible score,
        // so every pressure check trips it.
        pool.register_pressure_callback(
            -0.1, // threshold below any possible score → always fires
            Box::new(move |_score| {
                let mut f = fired_clone.lock().expect("lock");
                *f = true;
            }),
        );

        // Allocate and free to create some fragmentation.
        let req = AsyncAllocRequest::new(64, 5);
        let handle = pool.enqueue(req).expect("enqueue");
        let completed = pool.process_queue(1);
        assert_eq!(completed.len(), 1);
        pool.free(completed[0].1).expect("free");

        // Enqueue another and process — check_pressure is called in process_queue.
        let req2 = AsyncAllocRequest::new(64, 5);
        pool.enqueue(req2).expect("enqueue 2");
        pool.process_queue(1);

        // Even if fragmentation is 0, the threshold -0.1 ensures the callback fires.
        // call check_pressure directly to be sure.
        pool.check_pressure();

        let was_fired = *fired.lock().expect("lock");
        assert!(was_fired, "pressure callback should have fired");
        let _ = handle; // suppress unused warning
    }

    /// Sanity-check the documented `PoolConfig` defaults.
    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.total_size, 64 * 1024 * 1024);
        assert_eq!(config.min_block_size, 64);
        assert_eq!(config.alignment, 256);
        assert!((config.defrag_threshold - 0.4).abs() < 1e-9);
        assert_eq!(config.async_queue_size, 1024);
    }
}
456}