#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]
#![cfg_attr(nightly, feature(allocator_api))]
use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::ptr::NonNull;
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering;

#[cfg(feature = "alloc")]
extern crate alloc;

#[cfg(not(nightly))]
use allocator_api2::alloc::{AllocError, Allocator, Layout};
#[cfg(nightly)]
use core::alloc::{AllocError, Allocator, Layout};

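/// A fixed-capacity, thread-safe bump ("stack") allocator backed by an
/// `N`-byte buffer stored inline in the allocator itself.
///
/// Allocation bumps an atomic offset forward; only the most recently
/// allocated block can be reclaimed by `deallocate`, and [`reset`](Self::reset)
/// makes the whole buffer available again.
///
/// # Examples
///
/// A minimal usage sketch through the `Allocator` trait. It is marked
/// `ignore` because `your_crate` below is a placeholder for this crate's
/// actual name; the imports assume the stable `allocator_api2` path this
/// crate uses when not on nightly.
///
/// ```ignore
/// use allocator_api2::alloc::{Allocator, Layout};
/// use your_crate::StackAllocator; // placeholder path
///
/// let stack = StackAllocator::<256>::new();
/// let layout = Layout::array::<u32>(8).unwrap();
/// let block = stack.allocate(layout).expect("fits in the 256-byte buffer");
/// assert_eq!(block.len(), layout.size());
/// // Freeing the most recent allocation rolls the bump offset back.
/// unsafe { stack.deallocate(block.cast(), layout) };
/// ```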
pub struct StackAllocator<const N: usize> {
    /// Backing storage; bytes are handed out from the front of this buffer.
    buf: UnsafeCell<MaybeUninit<[u8; N]>>,
    /// Current bump offset into `buf`, in bytes.
    offset: AtomicUsize,
}

impl<const N: usize> Default for StackAllocator<N> {
    fn default() -> Self {
        Self::new()
    }
}

// SAFETY: the buffer is only ever carved up through atomic updates of
// `offset`, so sharing the allocator across threads cannot hand the same
// bytes to two callers.
unsafe impl<const N: usize> Send for StackAllocator<N> {}
unsafe impl<const N: usize> Sync for StackAllocator<N> {}

impl<const N: usize> StackAllocator<N> {
    /// Creates an empty allocator backed by an uninitialized `N`-byte buffer.
    pub const fn new() -> Self {
        Self {
            buf: UnsafeCell::new(MaybeUninit::uninit()),
            offset: AtomicUsize::new(0),
        }
    }

    /// Resets the bump offset to zero, making the whole buffer available again.
    ///
    /// # Safety
    ///
    /// No allocation previously handed out by this allocator may still be in
    /// use: resetting invalidates every outstanding pointer.
    pub unsafe fn reset(&mut self) {
        self.offset.store(0, Ordering::Release);
    }

    /// Rounds `addr` up to the next multiple of `align`.
    /// `align` must be a power of two, which `Layout` guarantees.
    #[inline]
    const fn align_up(addr: usize, align: usize) -> usize {
        (addr + align - 1) & !(align - 1)
    }

    /// Returns the number of bytes of the buffer currently in use.
    pub fn current_offset(&self) -> usize {
        self.offset.load(Ordering::Acquire)
    }
}

unsafe impl<const N: usize> Allocator for StackAllocator<N> {
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        // The backing buffer is only byte-aligned, so alignment has to be
        // computed against the buffer's real address, not the offset alone.
        let base = self.buf.get() as usize;
        let mut start;
        loop {
            let current = self.offset.load(Ordering::Acquire);
            start = Self::align_up(base + current, layout.align()) - base;
            let end = start.checked_add(layout.size()).ok_or(AllocError)?;

            if end > N {
                return Err(AllocError);
            }

            // Claim `[start, end)` by bumping the offset; retry if another
            // thread allocated concurrently.
            if self
                .offset
                .compare_exchange(current, end, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }
        }
        let ptr = unsafe { self.buf.get().cast::<u8>().add(start) };
        Ok(NonNull::slice_from_raw_parts(
            NonNull::new(ptr).unwrap(),
            layout.size(),
        ))
    }

    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        let base = self.buf.get() as usize;
        let start = ptr.as_ptr() as usize - base;
        let end = start + layout.size();

        // Only the most recent allocation can be reclaimed: if `offset` still
        // points at this block's end, roll it back to the block's start.
        // Otherwise the bytes are simply leaked until `reset` is called.
        let _ = self
            .offset
            .compare_exchange(end, start, Ordering::Release, Ordering::Relaxed);
    }

    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let base = self.buf.get() as usize;
        let old_start = ptr.as_ptr() as usize - base;

        // Growing in place is only possible for the most recent allocation.
        let expected_offset = old_start + old_layout.size();
        let current_offset = self.offset.load(Ordering::Acquire);
        if current_offset != expected_offset {
            return Err(AllocError);
        }

        // `grow` must never be asked to shrink; reject out of caution.
        if new_layout.size() < old_layout.size() {
            return Err(AllocError);
        }

        let new_end = old_start.checked_add(new_layout.size()).ok_or(AllocError)?;
        if new_end > N {
            return Err(AllocError);
        }

        if self
            .offset
            .compare_exchange(
                expected_offset,
                new_end,
                Ordering::Release,
                Ordering::Relaxed,
            )
            .is_err()
        {
            // Another thread allocated on top of this block in the meantime,
            // so it can no longer grow in place. Allocate a fresh block and,
            // as the `grow` contract requires, copy the old contents into it;
            // the old bytes are leaked until `reset`.
            let new_ptr = self.allocate(new_layout)?;
            core::ptr::copy_nonoverlapping(
                ptr.as_ptr(),
                new_ptr.cast::<u8>().as_ptr(),
                old_layout.size(),
            );
            return Ok(new_ptr);
        }
        Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size()))
    }

    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let base = self.buf.get() as usize;
        let old_start = ptr.as_ptr() as usize - base;

        // Shrinking in place is only possible for the most recent allocation.
        let expected_offset = old_start + old_layout.size();
        let current_offset = self.offset.load(Ordering::Acquire);
        if current_offset != expected_offset {
            return Err(AllocError);
        }

        // `shrink` must never be asked to grow; reject out of caution.
        if new_layout.size() > old_layout.size() {
            return Err(AllocError);
        }

        let new_end = old_start + new_layout.size();

        // Try to give the tail back. If another thread allocated in the
        // meantime the exchange fails, which is fine: the block keeps its old
        // footprint and only the reported length shrinks.
        let _ = self.offset.compare_exchange(
            expected_offset,
            new_end,
            Ordering::Release,
            Ordering::Relaxed,
        );

        Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size()))
    }
    fn by_ref(&self) -> &Self
    where
        Self: Sized,
    {
        self
    }
}

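/// An allocator that serves requests from a fixed `N`-byte stack buffer and
/// transparently falls back to another allocator `F` once the buffer cannot
/// satisfy a request.
///
/// # Examples
///
/// A sketch of the intended use with a heap fallback. It is marked `ignore`
/// because `your_crate` is a placeholder for this crate's actual name, and it
/// assumes the `allocator_api2` `Global` allocator on stable.
///
/// ```ignore
/// use allocator_api2::alloc::{Allocator, Global, Layout};
/// use your_crate::HybridAllocator; // placeholder path
///
/// // 64 bytes served from the stack buffer, everything else from the heap.
/// let hybrid = HybridAllocator::<64, Global>::new(Global);
/// let small = hybrid.allocate(Layout::new::<[u8; 32]>()).unwrap();
/// let large = hybrid.allocate(Layout::new::<[u8; 1024]>()).unwrap(); // heap
/// ```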
pub struct HybridAllocator<const N: usize, F: Allocator> {
    /// Fixed-size bump allocator tried first for every request.
    stack_alloc: StackAllocator<N>,
    /// Allocator used once the stack buffer cannot satisfy a request.
    fallback: F,
}

#[cfg(all(feature = "alloc", not(nightly)))]
impl<const N: usize> Default for HybridAllocator<N, allocator_api2::alloc::Global> {
    fn default() -> Self {
        Self::new(allocator_api2::alloc::Global)
    }
}

#[cfg(all(feature = "alloc", nightly))]
impl<const N: usize> Default for HybridAllocator<N, alloc::alloc::Global> {
    fn default() -> Self {
        Self::new(alloc::alloc::Global)
    }
}

impl<const N: usize, F: Allocator> HybridAllocator<N, F> {
    /// Creates a hybrid allocator that bump-allocates from an internal
    /// `N`-byte buffer and defers to `fallback` once that buffer is full.
    pub const fn new(fallback: F) -> Self {
        Self {
            stack_alloc: StackAllocator::new(),
            fallback,
        }
    }

    /// Resets the internal stack allocator, making its whole buffer available
    /// again. Allocations served by the fallback allocator are unaffected.
    ///
    /// # Safety
    ///
    /// No stack-backed allocation handed out by this allocator may still be
    /// in use: resetting invalidates every such pointer.
    pub unsafe fn reset(&mut self) {
        self.stack_alloc.reset();
    }

    /// Returns the number of bytes of the stack buffer currently in use.
    pub fn current_offset(&self) -> usize {
        self.stack_alloc.current_offset()
    }

    /// Returns a reference to the fallback allocator.
    pub fn fallback(&self) -> &F {
        &self.fallback
    }

    /// Returns `true` once the stack buffer is completely full.
    pub fn is_stack_exhausted(&self) -> bool {
        self.current_offset() >= N
    }
}

unsafe impl<const N: usize, F: Allocator> Allocator for HybridAllocator<N, F> {
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        // Try the stack buffer first and fall back to `F` if it is full (or
        // the request simply does not fit).
        match self.stack_alloc.allocate(layout) {
            ok @ Ok(_) => ok,
            Err(_) => self.fallback.allocate(layout),
        }
    }

    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        // Route the pointer back to whichever allocator owns it, based on
        // whether it falls inside the stack buffer.
        let base = self.stack_alloc.buf.get() as usize;
        let end = base + N;
        let addr = ptr.as_ptr() as usize;

        if (base..end).contains(&addr) {
            self.stack_alloc.deallocate(ptr, layout);
        } else {
            self.fallback.deallocate(ptr, layout);
        }
    }

    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let base = self.stack_alloc.buf.get() as usize;
        let addr = ptr.as_ptr() as usize;

        if (base..base + N).contains(&addr) {
            match self.stack_alloc.grow(ptr, old_layout, new_layout) {
                Ok(res) => return Ok(res),
                Err(_) => {
                    // The stack cannot grow this block, so move it to the
                    // fallback allocator and copy the existing contents over.
                    let new_ptr = self.fallback.allocate(new_layout)?;
                    core::ptr::copy_nonoverlapping(
                        ptr.as_ptr(),
                        new_ptr.cast::<u8>().as_ptr(),
                        old_layout.size(),
                    );
                    // Best effort: only reclaims the bytes if this was the
                    // most recent stack allocation.
                    self.stack_alloc.deallocate(ptr, old_layout);
                    return Ok(new_ptr);
                }
            }
        }
        self.fallback.grow(ptr, old_layout, new_layout)
    }

    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let base = self.stack_alloc.buf.get() as usize;
        let addr = ptr.as_ptr() as usize;

        if (base..base + N).contains(&addr) {
            // Never hand a stack-owned pointer to the fallback allocator. If
            // the stack cannot shrink the block in place, keep it where it is
            // and simply report the smaller length.
            return match self.stack_alloc.shrink(ptr, old_layout, new_layout) {
                Ok(res) => Ok(res),
                Err(_) => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
            };
        }
        self.fallback.shrink(ptr, old_layout, new_layout)
    }

    fn by_ref(&self) -> &Self
    where
        Self: Sized,
    {
        self
    }
}
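
// A small test module sketching how the two allocators are expected to
// behave. The scenarios and the `NeverAlloc` helper are illustrative
// additions, not part of the original API.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fallback that always fails, used to observe which allocator
    /// actually served a request.
    struct NeverAlloc;

    unsafe impl Allocator for NeverAlloc {
        fn allocate(&self, _layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
            Err(AllocError)
        }

        unsafe fn deallocate(&self, _ptr: NonNull<u8>, _layout: Layout) {}
    }

    #[test]
    fn stack_allocate_then_free_rolls_the_offset_back() {
        let stack = StackAllocator::<64>::new();
        let layout = Layout::from_size_align(16, 8).unwrap();

        let block = stack.allocate(layout).expect("16 bytes fit in 64");
        assert!(stack.current_offset() >= 16);

        // The block is the most recent allocation, so freeing it reclaims it.
        unsafe { stack.deallocate(block.cast(), layout) };
        assert!(stack.current_offset() < 16);
    }

    #[test]
    fn hybrid_uses_the_fallback_once_the_stack_is_full() {
        let hybrid = HybridAllocator::<8, NeverAlloc>::new(NeverAlloc);
        let layout = Layout::from_size_align(8, 1).unwrap();

        // The first request fits in the 8-byte stack buffer.
        assert!(hybrid.allocate(layout).is_ok());
        assert!(hybrid.is_stack_exhausted());

        // The buffer is now full and `NeverAlloc` refuses everything, so the
        // second request fails.
        assert!(hybrid.allocate(layout).is_err());
    }
}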