use std::{alloc::Layout, ptr::NonNull};

use crate::Allocator;

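/// A fixed-capacity byte buffer that objects are bump-allocated into.
///
/// `offset` marks where the next (or currently in-progress) object starts and
/// `len` is the number of bytes written to that in-progress object so far.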
#[derive(Debug)]
struct Chunk {
    buffer: Box<[u8]>,
    offset: usize,
    len: usize,
}

impl Chunk {
    fn new(capacity: usize) -> Self {
        Self {
            buffer: unsafe { Box::new_uninit_slice(capacity).assume_init() },
            offset: 0,
            len: 0,
        }
    }

    fn capacity(&self) -> usize {
        self.buffer.len()
    }

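    /// Returns a pointer to the in-progress object: `len` bytes starting at `offset`.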
    fn object(&self) -> NonNull<[u8]> {
        NonNull::slice_from_raw_parts(
            unsafe { NonNull::new_unchecked(self.buffer.as_ptr().add(self.offset).cast_mut()) },
            self.len,
        )
    }
}

impl From<Box<[u8]>> for Chunk {
    fn from(value: Box<[u8]>) -> Self {
        Self {
            buffer: value,
            offset: 0,
            len: 0,
        }
    }
}

impl From<Chunk> for Box<[u8]> {
    fn from(value: Chunk) -> Self {
        value.buffer
    }
}

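/// A stack-like bump arena for byte objects.
///
/// Objects are assembled in the `current` chunk (via `extend` or `write!`) and
/// sealed with `finish`; pointers to completed objects are kept on `stack`, and
/// chunks that fill up are retired into `store` so their objects remain valid.
///
/// A minimal usage sketch (import paths are assumed; adjust for this crate's layout):
///
/// ```ignore
/// let mut arena = StackArena::new();
/// arena.extend(b"hello ");
/// arena.extend(b"world");
/// let greeting = unsafe { arena.finish().as_ref() };
/// assert_eq!(greeting, b"hello world");
///
/// arena.push(b"!"); // `push` copies a complete object in one call
/// assert_eq!(arena.len(), 2);
/// arena.pop(); // LIFO removal of the most recently finished object
/// ```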
#[derive(Debug)]
pub struct StackArena {
    store: Vec<Box<[u8]>>,
    stack: Vec<NonNull<[u8]>>,
    current: Chunk,
}

impl StackArena {
    pub fn new() -> Self {
        Self {
            store: Vec::new(),
            stack: Vec::new(),
            current: Chunk::new(1024),
        }
    }

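    /// Returns the number of finished objects currently on the stack.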
    pub fn len(&self) -> usize {
        self.stack.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

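    /// Copies `data` into the arena as a complete object and returns a pointer to it.
    ///
    /// Must not be called while a partial object started with `extend` is pending.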
    pub fn push<P: AsRef<[u8]>>(&mut self, data: P) -> NonNull<[u8]> {
        let data = data.as_ref();
        let object = unsafe { self.allocate(Layout::for_value(data)) }.unwrap();
        unsafe {
            object.cast().copy_from_nonoverlapping(
                NonNull::new_unchecked(data.as_ptr().cast_mut()),
                data.len(),
            )
        };
        self.finish()
    }

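    /// Removes the most recently finished object, rewinding the current chunk's
    /// write offset by that object's length.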
    pub fn pop(&mut self) {
        debug_assert_eq!(
            self.current.len, 0,
            "cannot pop while a partial object is in progress"
        );
        let object = self.stack.pop().expect("stack underflow");
        self.current.offset -= object.len();
    }

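    /// Appends `data` to the in-progress object, growing it in place; call
    /// `finish` to seal the object once it is complete.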
    pub fn extend<P: AsRef<[u8]>>(&mut self, data: P) {
        let data = data.as_ref();
        unsafe {
            let len = self.current.len;
            let ptr = self.current.buffer.as_mut_ptr().add(self.current.offset);
            let object = self
                .grow(
                    NonNull::new_unchecked(ptr),
                    Layout::array::<u8>(self.current.len).unwrap(),
                    Layout::array::<u8>(self.current.len + data.len()).unwrap(),
                )
                .unwrap();
            object.cast().add(len).copy_from_nonoverlapping(
                NonNull::new_unchecked(data.as_ptr().cast_mut()),
                data.len(),
            )
        }
    }

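    /// Seals the in-progress object, pushes it onto the stack, and returns a
    /// pointer to it.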
    pub fn finish(&mut self) -> NonNull<[u8]> {
        let object = self.current.object();
        self.current.offset += object.len();
        self.current.len = 0;
        self.stack.push(object);
        object
    }

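    /// Deallocates `data`, popping it and every object finished after it off
    /// the stack.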
    pub fn free(&mut self, data: &[u8]) {
        unsafe {
            self.deallocate(
                NonNull::new_unchecked(data.as_ptr().cast_mut()),
                Layout::for_value(data),
            )
        };
    }
}

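/// `write!` output is appended to the in-progress object via `extend`; call
/// `finish` to obtain the formatted bytes.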
impl std::fmt::Write for StackArena {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.extend(s);
        Ok(())
    }
}

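/// Bump allocation out of `current`, spilling full chunks into `store`. Only
/// the most recent, still-unfinished allocation can be grown or shrunk in place.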
impl Allocator for StackArena {
    unsafe fn allocate(
        &mut self,
        layout: std::alloc::Layout,
    ) -> Result<std::ptr::NonNull<[u8]>, crate::AllocError> {
        debug_assert_eq!(self.current.len, 0, "use grow instead");
        let mut offset = self.current.offset
            + self
                .current
                .buffer
                .as_ptr()
                .add(self.current.offset)
                .align_offset(layout.align());
        if offset + layout.size() > self.current.capacity() {
            let mut capacity = self.current.capacity();
            while capacity < layout.size() {
                capacity *= 2;
            }
            let old_chunk = std::mem::replace(&mut self.current, Chunk::new(capacity));
            if old_chunk.offset != 0 {
                self.store.push(old_chunk.into());
            }
            offset = self.current.buffer.as_ptr().align_offset(layout.align());
        }
        self.current.offset = offset;
        self.current.len = layout.size();
        Ok(self.current.object())
    }

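    /// Pops entries off the object stack until the slice described by `ptr` and
    /// `layout` is found; everything finished after it is discarded as well.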
    unsafe fn deallocate(&mut self, ptr: std::ptr::NonNull<u8>, layout: std::alloc::Layout) {
        let object = NonNull::slice_from_raw_parts(ptr, layout.size());
        while let Some(top) = self.stack.pop() {
            if std::ptr::eq(object.as_ptr(), top.as_ptr()) {
                break;
            }
        }
    }

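    /// Grows the in-progress allocation in place, moving it (and copying the
    /// bytes written so far) into a larger chunk when the current one is full.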
    unsafe fn grow(
        &mut self,
        ptr: NonNull<u8>,
        old_layout: std::alloc::Layout,
        new_layout: std::alloc::Layout,
    ) -> Result<NonNull<[u8]>, crate::AllocError> {
        match old_layout.size().cmp(&new_layout.size()) {
            std::cmp::Ordering::Less => {
                debug_assert_eq!(
                    ptr.as_ptr().cast_const(),
                    self.current.buffer.as_ptr().add(self.current.offset)
                );
                debug_assert_eq!(old_layout.size(), self.current.len);
                debug_assert_eq!(old_layout.align(), new_layout.align());

                let mut capacity = self.current.capacity();
                if self.current.offset + new_layout.size() > capacity {
                    while capacity < new_layout.size() {
                        capacity *= 2;
                    }
                    let old_chunk = std::mem::replace(&mut self.current, Chunk::new(capacity));
                    let object = old_chunk.object();
                    self.current.offset = self
                        .current
                        .buffer
                        .as_ptr()
                        .align_offset(new_layout.align());
                    self.current
                        .buffer
                        .as_mut_ptr()
                        .add(self.current.offset)
                        .copy_from_nonoverlapping(object.as_ptr().cast(), object.len());
                    if old_chunk.offset != 0 {
                        self.store.push(old_chunk.into());
                    }
                }
                self.current.len = new_layout.size();
                Ok(self.current.object())
            }
            std::cmp::Ordering::Equal => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
            std::cmp::Ordering::Greater => panic!("use shrink instead"),
        }
    }

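    /// Shrinks the in-progress allocation in place by reducing its recorded length.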
    unsafe fn shrink(
        &mut self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, crate::AllocError> {
        match old_layout.size().cmp(&new_layout.size()) {
            std::cmp::Ordering::Less => panic!("use grow instead"),
            std::cmp::Ordering::Equal => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
            std::cmp::Ordering::Greater => {
                debug_assert_eq!(
                    ptr.as_ptr().cast_const(),
                    self.current.buffer.as_ptr().add(self.current.offset)
                );
                debug_assert_eq!(old_layout.size(), self.current.len);
                debug_assert_eq!(old_layout.align(), new_layout.align());
                self.current.len = new_layout.size();
                Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size()))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use super::*;

    #[test]
    fn test_lifecycle() {
        let mut stack = StackArena::new();
        write!(&mut stack, "ab").expect("write");
        let s = "c";
        stack.extend(s);
        let p = unsafe { stack.finish().as_ref() };
        assert_eq!(p, b"abc");
    }

    #[test]
    fn test_new() {
        let stack = StackArena::new();
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
        assert_eq!(stack.store.len(), 0);
        assert_eq!(stack.stack.len(), 0);
    }

    #[test]
    fn test_push_pop() {
        let mut stack = StackArena::new();

        stack.push(b"hello");
        assert_eq!(stack.len(), 1);
        assert!(!stack.is_empty());

        stack.push(b"world");
        assert_eq!(stack.len(), 2);

        stack.pop();
        assert_eq!(stack.len(), 1);

        stack.pop();
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    fn test_extend_small_data() {
        let mut stack = StackArena::new();

        stack.extend(b"hello");
        stack.extend(b" world");

        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_extend_large_data() {
        let mut stack = StackArena::new();

        let large_data = vec![b'x'; 20];
        stack.extend(&large_data);

        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, &large_data[..]);
    }

    #[test]
    fn test_extend_after_finish() {
        let mut stack = StackArena::new();

        stack.extend("first");
        let first = unsafe { stack.finish().as_ref() };
        assert_eq!(first, b"first");
        eprintln!("{stack:?}");
        stack.extend(b"second");
        let second = unsafe { stack.finish().as_ref() };
        eprintln!("{stack:?}");

        assert_eq!(second, b"second");

        assert_eq!(first, b"first");
        assert_eq!(second, b"second");
    }

    #[test]
    fn test_free() {
        let mut stack = StackArena::new();

        stack.push(b"first");
        stack.extend(b"second");
        let second = unsafe { stack.finish().as_ref() };
        stack.extend(b"third");
        let _third = unsafe { stack.finish().as_ref() };

        stack.free(second);

        assert_eq!(stack.len(), 1);

        stack.extend(b"fourth");
        let fourth = unsafe { stack.finish().as_ref() };
        assert_eq!(fourth, b"fourth");

        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_write_trait() {
        let mut stack = StackArena::new();

        write!(&mut stack, "Hello, {}!", "world").unwrap();
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"Hello, world!");
    }

    #[test]
    fn test_empty_data() {
        let mut stack = StackArena::new();

        stack.extend(b"");
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"");

        stack.push(b"");
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_multiple_operations() {
        let mut stack = StackArena::new();

        stack.push(b"item1");
        stack.extend(b"item2-part1");
        stack.extend(b"-part2");
        let item2 = unsafe { stack.finish().as_ref() };
        write!(&mut stack, "item3").unwrap();
        let item3 = unsafe { stack.finish().as_ref() };

        assert_eq!(item2, b"item2-part1-part2");
        assert_eq!(item3, b"item3");
        assert_eq!(stack.len(), 3);

        stack.pop();
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_extend_exact_capacity() {
        let mut stack = StackArena::new();

        let data = vec![b'x'; 10];
        stack.extend(&data);

        stack.extend(b"more");

        let result = unsafe { stack.finish().as_ref() };
        let mut expected = data.clone();
        expected.extend_from_slice(b"more");
        assert_eq!(result, expected.as_slice());
    }

    #[test]
    fn test_free_all() {
        let mut stack = StackArena::new();

        stack.push(b"first");
        stack.extend(b"second");
        let _second = unsafe { stack.finish().as_ref() };

        let first = unsafe { stack.stack[0].as_ref() };
        stack.free(first);

        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    #[should_panic]
    fn test_free_nonexistent() {
        let mut stack = StackArena::new();

        stack.push(b"object");
        assert_eq!(stack.len(), 1);

        let dummy = b"nonexistent";
        stack.free(dummy);

        // `free` never finds `dummy`, so it pops every object off the stack; this
        // assertion then fails, which is the panic the test expects.
        assert_eq!(stack.len(), 1);
    }
}