stack_arena/
stack_arena.rs

1use std::{alloc::Layout, ptr::NonNull};
2
3use crate::Allocator;
4
/// A fixed-capacity byte buffer plus the window of the object currently being
/// built inside it.
#[derive(Debug)]
struct Chunk {
    /// Backing storage; its length is the capacity of the chunk.
    buffer: Box<[u8]>,
    /// Byte offset into `buffer` at which the current object starts.
    offset: usize,
    /// Length in bytes of the current object (0 when no object is open).
    len: usize,
}

impl Chunk {
    /// Creates a chunk of `capacity` zeroed bytes with an empty object window.
    fn new(capacity: usize) -> Self {
        Self {
            // Zero-fill rather than `new_uninit_slice(..).assume_init()`:
            // declaring never-written bytes initialised is undefined behaviour,
            // and the derived `Debug` impl reads the whole buffer.
            buffer: vec![0u8; capacity].into_boxed_slice(),
            offset: 0,
            len: 0,
        }
    }

    /// Returns the total number of bytes this chunk can hold.
    fn capacity(&self) -> usize {
        self.buffer.len()
    }

    /// Returns a pointer to the current object: `len` bytes starting at `offset`.
    ///
    /// NOTE(review): the pointer is derived from `&self`, yet callers write
    /// through it; a `&mut self` accessor would give it write provenance.
    fn object(&self) -> NonNull<[u8]> {
        // SAFETY: a `Box` pointer is never null, and `offset` is kept within
        // `0..=capacity` by the code that updates it, so the offset pointer
        // stays inside (or one past the end of) the allocation.
        let start =
            unsafe { NonNull::new_unchecked(self.buffer.as_ptr().add(self.offset).cast_mut()) };
        NonNull::slice_from_raw_parts(start, self.len)
    }
}
32
33impl From<Box<[u8]>> for Chunk {
34    fn from(value: Box<[u8]>) -> Self {
35        Self {
36            buffer: value,
37            offset: 0,
38            len: 0,
39        }
40    }
41}
42
43impl Into<Box<[u8]>> for Chunk {
44    fn into(self) -> Box<[u8]> {
45        self.buffer
46    }
47}
48
/// An arena that builds byte objects back-to-back inside chunks and keeps the
/// finished objects on a last-in, first-out stack.
#[derive(Debug)]
pub struct StackArena {
    /// Buffers of chunks that were replaced by a larger/newer chunk while they
    /// still held finished objects.
    store: Vec<Box<[u8]>>,
    /// Pointers to the finished objects, oldest first.
    stack: Vec<NonNull<[u8]>>,
    /// The chunk new objects are currently carved out of.
    current: Chunk,
}
55
56impl StackArena {
57    pub fn new() -> Self {
58        Self {
59            store: Vec::new(),
60            stack: Vec::new(),
61            current: Chunk::new(1024),
62        }
63    }
64
65    pub fn len(&self) -> usize {
66        self.stack.len()
67    }
68
69    pub fn is_empty(&self) -> bool {
70        self.len() == 0
71    }
72
73    pub fn push<P: AsRef<[u8]>>(&mut self, data: P) -> NonNull<[u8]> {
74        let data = data.as_ref();
75        let object = unsafe { self.allocate(Layout::for_value(data)) }.unwrap();
76        unsafe {
77            object.cast().copy_from_nonoverlapping(
78                NonNull::new_unchecked(data.as_ptr().cast_mut()),
79                data.len(),
80            )
81        };
82        self.finish()
83    }
84
85    pub fn pop(&mut self) {
86        debug_assert_eq!(
87            self.current.len, 0,
88            "cannot pop while having partial object"
89        );
90        let object = self.stack.pop().expect("stack underflow");
91        self.current.offset -= object.len();
92    }
93
94    pub fn extend<P: AsRef<[u8]>>(&mut self, data: P) {
95        let data = data.as_ref();
96        unsafe {
97            let len = self.current.len;
98            let ptr = self.current.buffer.as_mut_ptr().add(self.current.offset);
99            let object = self
100                .grow(
101                    NonNull::new_unchecked(ptr),
102                    Layout::array::<u8>(self.current.len).unwrap(),
103                    Layout::array::<u8>(self.current.len + data.len()).unwrap(),
104                )
105                .unwrap();
106            object.cast().add(len).copy_from_nonoverlapping(
107                NonNull::new_unchecked(data.as_ptr().cast_mut()),
108                data.len(),
109            )
110        }
111    }
112
113    pub fn finish(&mut self) -> NonNull<[u8]> {
114        let object = self.current.object();
115        self.current.offset += object.len();
116        self.current.len = 0;
117        self.stack.push(object);
118        object
119    }
120
121    pub fn free(&mut self, data: &[u8]) {
122        let data = data.as_ref();
123        unsafe {
124            self.deallocate(
125                NonNull::new_unchecked(data.as_ptr().cast_mut()),
126                Layout::for_value(data),
127            )
128        };
129    }
130}
131
132impl std::fmt::Write for StackArena {
133    fn write_str(&mut self, s: &str) -> std::fmt::Result {
134        self.extend(s);
135        Ok(())
136    }
137}
138
139impl Allocator for StackArena {
140    unsafe fn allocate(
141        &mut self,
142        layout: std::alloc::Layout,
143    ) -> Result<std::ptr::NonNull<[u8]>, crate::AllocError> {
144        debug_assert_eq!(self.current.len, 0, "use grow instead");
145        let mut offset = self.current.offset
146            + self
147                .current
148                .buffer
149                .as_ptr()
150                .add(self.current.offset)
151                .align_offset(layout.align());
152        if offset + layout.size() > self.current.capacity() {
153            let mut capacity = self.current.capacity();
154            while capacity < layout.size() {
155                capacity *= 2;
156            }
157            let old_chunk = std::mem::replace(&mut self.current, Chunk::new(capacity));
158            if old_chunk.offset != 0 {
159                self.store.push(old_chunk.into());
160            }
161            offset = self.current.buffer.as_ptr().align_offset(layout.align());
162        }
163        self.current.offset = offset;
164        self.current.len = layout.size();
165        Ok(self.current.object())
166    }
167
168    unsafe fn deallocate(&mut self, ptr: std::ptr::NonNull<u8>, layout: std::alloc::Layout) {
169        let object = NonNull::slice_from_raw_parts(ptr, layout.size());
170        while let Some(top) = self.stack.pop() {
171            // TODO: self.store.pop_if(xxx)
172            if std::ptr::eq(object.as_ptr(), top.as_ptr()) {
173                break;
174            }
175        }
176    }
177
178    unsafe fn grow(
179        &mut self,
180        ptr: NonNull<u8>,
181        old_layout: std::alloc::Layout,
182        new_layout: std::alloc::Layout,
183    ) -> Result<NonNull<[u8]>, crate::AllocError> {
184        match old_layout.size().cmp(&new_layout.size()) {
185            std::cmp::Ordering::Less => {
186                debug_assert_eq!(
187                    ptr.as_ptr().cast_const(),
188                    self.current.buffer.as_ptr().add(self.current.offset)
189                );
190                debug_assert_eq!(old_layout.size(), self.current.len);
191                debug_assert_eq!(old_layout.align(), new_layout.align());
192
193                let mut capacity = self.current.capacity();
194                if self.current.offset + new_layout.size() > capacity {
195                    while capacity < new_layout.size() {
196                        capacity *= 2;
197                    }
198                    let old_chunk = std::mem::replace(&mut self.current, Chunk::new(capacity));
199                    let object = old_chunk.object();
200                    self.current.offset = self
201                        .current
202                        .buffer
203                        .as_ptr()
204                        .align_offset(new_layout.align());
205                    self.current
206                        .buffer
207                        .as_mut_ptr()
208                        .add(self.current.offset)
209                        .copy_from_nonoverlapping(object.as_ptr().cast(), object.len());
210                    if old_chunk.offset != 0 {
211                        self.store.push(old_chunk.into());
212                    }
213                }
214                self.current.len = new_layout.size();
215                Ok(self.current.object())
216            }
217            std::cmp::Ordering::Equal => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
218            std::cmp::Ordering::Greater => panic!("use shrink instead"),
219        }
220    }
221
222    unsafe fn shrink(
223        &mut self,
224        ptr: NonNull<u8>,
225        old_layout: Layout,
226        new_layout: Layout,
227    ) -> Result<NonNull<[u8]>, crate::AllocError> {
228        match old_layout.size().cmp(&new_layout.size()) {
229            std::cmp::Ordering::Less => panic!("use grow instead"),
230            std::cmp::Ordering::Equal => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
231            std::cmp::Ordering::Greater => {
232                debug_assert_eq!(
233                    ptr.as_ptr().cast_const(),
234                    self.current.buffer.as_ptr().add(self.current.offset)
235                );
236                debug_assert_eq!(old_layout.size(), self.current.len);
237                debug_assert_eq!(old_layout.align(), new_layout.align());
238                self.current.len = new_layout.size();
239                Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size()))
240            }
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use super::*;

    /// Capacity of the chunk a fresh arena starts with.
    const INITIAL_CAPACITY: usize = 1024;

    #[test]
    fn test_lifecycle() {
        let mut stack = StackArena::new();
        write!(&mut stack, "ab").expect("write");
        let s = "c";
        stack.extend(s);
        let p = unsafe { stack.finish().as_ref() };
        assert_eq!(p, b"abc");
    }

    #[test]
    fn test_new() {
        let stack = StackArena::new();
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
        assert_eq!(stack.store.len(), 0);
        assert_eq!(stack.stack.len(), 0);
        assert_eq!(stack.current.capacity(), INITIAL_CAPACITY);
    }

    #[test]
    fn test_push_pop() {
        let mut stack = StackArena::new();

        // Test push
        stack.push(b"hello");
        assert_eq!(stack.len(), 1);
        assert!(!stack.is_empty());

        // Test push multiple items
        stack.push(b"world");
        assert_eq!(stack.len(), 2);

        // Test pop
        stack.pop();
        assert_eq!(stack.len(), 1);

        // Test pop to empty
        stack.pop();
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    fn test_extend_small_data() {
        let mut stack = StackArena::new();

        // Extend with data smaller than chunk capacity
        stack.extend(b"hello");

        // Extend again with small data
        stack.extend(b" world");

        // Finish and verify
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_extend_large_data() {
        let mut stack = StackArena::new();

        // Extend with data that really is larger than the initial chunk, so
        // the chunk-replacement path is exercised.
        let large_data = vec![b'x'; INITIAL_CAPACITY * 4];
        stack.extend(&large_data);

        // Finish and verify
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, &large_data[..]);
    }

    #[test]
    fn test_extend_after_finish() {
        let mut stack = StackArena::new();

        // First object
        stack.extend("first");
        let first = unsafe { stack.finish().as_ref() };
        assert_eq!(first, b"first");

        // Second object
        stack.extend(b"second");
        let second = unsafe { stack.finish().as_ref() };
        assert_eq!(second, b"second");

        // Verify both objects are still valid
        assert_eq!(first, b"first");
        assert_eq!(second, b"second");
    }

    #[test]
    fn test_growth_retires_chunk() {
        let mut stack = StackArena::new();

        // Nearly fill the initial chunk with a finished object.
        let first = unsafe { stack.push(vec![b'a'; INITIAL_CAPACITY - 24]).as_ref() };

        // The next object does not fit behind it, so a new chunk is started
        // and the old one is kept alive in `store`.
        stack.extend(vec![b'b'; 100]);
        let second = unsafe { stack.finish().as_ref() };

        assert_eq!(stack.store.len(), 1);
        assert_eq!(first.len(), INITIAL_CAPACITY - 24);
        assert!(first.iter().all(|&b| b == b'a'));
        assert_eq!(second.len(), 100);
        assert!(second.iter().all(|&b| b == b'b'));
    }

    #[test]
    fn test_free() {
        let mut stack = StackArena::new();

        // Create multiple objects
        stack.push(b"first");
        stack.extend(b"second");
        let second = unsafe { stack.finish().as_ref() };
        stack.extend(b"third");
        let _third = unsafe { stack.finish().as_ref() };

        // Free up to second object
        stack.free(second);

        // Verify stack state
        assert_eq!(stack.len(), 1); // Only "first" remains

        // Add a new object
        stack.extend(b"fourth");
        let fourth = unsafe { stack.finish().as_ref() };
        assert_eq!(fourth, b"fourth");

        // Verify stack state
        assert_eq!(stack.len(), 2); // "first" and "fourth"
    }

    #[test]
    fn test_write_trait() {
        let mut stack = StackArena::new();

        // Test write_str via the Write trait
        write!(&mut stack, "Hello, {}!", "world").unwrap();
        // Finish and verify
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"Hello, world!");
    }

    #[test]
    fn test_empty_data() {
        let mut stack = StackArena::new();

        // Test with empty data
        stack.extend(b"");
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"");

        // Test push with empty data
        stack.push(b"");
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_multiple_operations() {
        let mut stack = StackArena::new();

        // Mix of operations
        stack.push(b"item1");
        stack.extend(b"item2-part1");
        stack.extend(b"-part2");
        let item2 = unsafe { stack.finish().as_ref() };
        write!(&mut stack, "item3").unwrap();
        let item3 = unsafe { stack.finish().as_ref() };

        // Verify
        assert_eq!(item2, b"item2-part1-part2");
        assert_eq!(item3, b"item3");
        assert_eq!(stack.len(), 3);

        // Pop and verify
        stack.pop();
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_extend_exact_capacity() {
        let mut stack = StackArena::new();

        // Fill the initial chunk exactly to capacity
        let data = vec![b'x'; INITIAL_CAPACITY];
        stack.extend(&data);

        // Add more data to trigger a move into a larger chunk
        stack.extend(b"more");

        // Finish and verify
        let result = unsafe { stack.finish().as_ref() };
        let mut expected = data.clone();
        expected.extend_from_slice(b"more");
        assert_eq!(result, expected.as_slice());
    }

    #[test]
    fn test_free_all() {
        let mut stack = StackArena::new();

        // Create multiple objects
        stack.push(b"first");
        stack.extend(b"second");
        let _second = unsafe { stack.finish().as_ref() };

        // Free all objects by freeing the first one
        let first = unsafe { stack.stack[0].as_ref() };
        stack.free(first);

        // Verify stack is empty
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    #[should_panic]
    fn test_free_nonexistent() {
        let mut stack = StackArena::new();

        // Create an object
        stack.push(b"object");
        assert_eq!(stack.len(), 1);

        // Try to free an object the arena never handed out
        let dummy = b"nonexistent";
        stack.free(dummy);

        // `free` discards entries while it searches, so a miss empties the
        // stack; this assertion is the panic the test expects.
        assert_eq!(stack.len(), 1);
    }
}