stalloc/
chain.rs

1use core::alloc::{GlobalAlloc, Layout};
2
/// An allocator that may sit at the front of a chain, with another allocator
/// chained behind it.
///
/// # Safety
/// Implementors must guarantee that `claims` returns `true` exactly when the
/// allocation was handed out by this allocator, and `false` when it came from
/// an allocator further up the chain. The chain uses this answer to decide
/// which allocator frees a pointer when the user calls `deallocate()` and
/// related functions.
pub unsafe trait ChainableAlloc {
	/// Reports whether the allocation identified by `ptr` and `layout` belongs
	/// to this allocator.
	///
	/// The chain calls this from `deallocate()` and related functions to work
	/// out which allocator has to free the pointer.
	fn claims(&self, ptr: *mut u8, layout: Layout) -> bool;
}
15
/// Two allocators linked together: requests go to the first allocator, and the
/// second one serves as a fallback once the first is exhausted.
///
/// The first allocator is stored by value, the fallback is borrowed.
///
/// # Examples
/// ```
/// // If the `SyncStalloc` is full, fall back to the system allocator.
/// use stalloc::{SyncStalloc, Stalloc};
/// use std::alloc::System;
///
/// let alloc_with_fallback = SyncStalloc::<1024, 8>::new().chain(&System);
///
/// let crazy_chain = Stalloc::<128, 4>::new()
///     .chain(&Stalloc::<1024, 8>::new())
///     .chain(&Stalloc::<8192, 16>::new())
///     .chain(&System);
/// ```
pub struct AllocChain<'a, A, B>(A, &'a B);
32
33impl<'a, A, B> AllocChain<'a, A, B> {
34	/// Initializes a new `AllocChain`.
35	pub const fn new(a: A, b: &'a B) -> Self {
36		Self(a, b)
37	}
38
39	/// Creates a new `AllocChain` containing this chain and `next`.
40	pub const fn chain<T>(self, next: &T) -> AllocChain<'_, Self, T>
41	where
42		Self: Sized,
43	{
44		AllocChain::new(self, next)
45	}
46}
47
48unsafe impl<A: GlobalAlloc + ChainableAlloc, B: GlobalAlloc> GlobalAlloc for AllocChain<'_, A, B> {
49	unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
50		let ptr_a = unsafe { self.0.alloc(layout) };
51		if ptr_a.is_null() {
52			unsafe { self.1.alloc(layout) }
53		} else {
54			ptr_a
55		}
56	}
57
58	unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
59		if self.0.claims(ptr, layout) {
60			unsafe { self.0.dealloc(ptr, layout) };
61		} else {
62			unsafe { self.1.dealloc(ptr, layout) };
63		}
64	}
65
66	unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
67		if self.0.claims(ptr, layout) {
68			let ptr_a = unsafe { self.0.realloc(ptr, layout, new_size) };
69			if !ptr_a.is_null() {
70				return ptr_a;
71			}
72
73			let layout_b = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
74			let ptr_b = unsafe { self.1.alloc(layout_b) };
75
76			if !ptr_b.is_null() {
77				// Copy the allocation from `A` to `B`.
78				unsafe {
79					ptr.copy_to_nonoverlapping(ptr_b, layout.size());
80					self.0.dealloc(ptr, layout);
81				}
82			}
83
84			// This is either a valid pointer or null.
85			ptr_b
86		} else {
87			unsafe { self.1.realloc(ptr, layout, new_size) }
88			// Don't fall back to `A`.
89		}
90	}
91}
92
93#[cfg(any(feature = "allocator-api", feature = "allocator-api2"))]
94use {
95	crate::{AllocError, Allocator},
96	core::ptr::NonNull,
97};
98
// Every request is forwarded to `A` or `B`. Which of the two releases or
// resizes a block is decided by `A::claims`.
#[cfg(any(feature = "allocator-api", feature = "allocator-api2"))]
unsafe impl<A: ChainableAlloc, B> Allocator for &AllocChain<'_, A, B>
where
	for<'a> &'a A: Allocator,
	for<'a> &'a B: Allocator,
{
	/// Allocates from `A`; if `A` reports an error, the request is retried on `B`.
	fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
		(&self.0)
			.allocate(layout)
			.or_else(|_| self.1.allocate(layout))
	}

	/// Frees `ptr` through `A` if `A` claims it, and through `B` otherwise.
	unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
		if self.0.claims(ptr.as_ptr(), layout) {
			// SAFETY: `A` claims the block; the rest is upheld by the caller.
			unsafe { (&self.0).deallocate(ptr, layout) };
		} else {
			// SAFETY: `A` does not claim the block, so it is handed to `B`.
			unsafe { self.1.deallocate(ptr, layout) };
		}
	}

	/// Grows the block at `ptr` from `old_layout` to `new_layout`.
	///
	/// A block claimed by `A` is first grown through `A`. If that fails, a new
	/// block is allocated from `B`, the old contents are copied over and the
	/// old block is freed through `A`. A block that `A` does not claim is
	/// grown through `B` only.
	///
	/// # Errors
	/// Returns the error from `B` if neither allocator could satisfy the
	/// request; in that case this function has neither freed nor moved the
	/// original block itself.
	unsafe fn grow(
		&self,
		ptr: NonNull<u8>,
		old_layout: Layout,
		new_layout: Layout,
	) -> Result<NonNull<[u8]>, AllocError> {
		if self.0.claims(ptr.as_ptr(), old_layout) {
			// SAFETY: `A` claims the block; the rest is upheld by the caller.
			let res_a = unsafe { (&self.0).grow(ptr, old_layout, new_layout) };
			if res_a.is_ok() {
				return res_a;
			}

			let res_b = self.1.allocate(new_layout);
			if let Ok(ptr_b) = res_b {
				// Copy the allocation from `A` to `B`, then release the old block.
				// SAFETY: Growing means the new block holds at least
				// `old_layout.size()` bytes, and `ptr` is still owned by `A`.
				unsafe {
					ptr.copy_to_nonoverlapping(ptr_b.cast(), old_layout.size());
					(&self.0).deallocate(ptr, old_layout);
				}
			}

			res_b
		} else {
			// Don't fall back to `A`.
			// SAFETY: Upheld by the caller.
			unsafe { self.1.grow(ptr, old_layout, new_layout) }
		}
	}

	/// Like `grow`, but every byte of the returned block past
	/// `old_layout.size()` is set to zero.
	unsafe fn grow_zeroed(
		&self,
		ptr: NonNull<u8>,
		old_layout: Layout,
		new_layout: Layout,
	) -> Result<NonNull<[u8]>, AllocError> {
		unsafe {
			// SAFETY: Upheld by the caller.
			let new_ptr = self.grow(ptr, old_layout, new_layout)?;
			// The returned slice may be longer than requested; zero all of the tail.
			let count = new_ptr.len() - old_layout.size();

			// SAFETY: We are filling in the extra capacity with zeros.
			new_ptr
				.cast::<u8>()
				.add(old_layout.size())
				.write_bytes(0, count);

			Ok(new_ptr)
		}
	}

	/// Shrinks the block at `ptr` from `old_layout` to `new_layout`.
	///
	/// A block claimed by `A` is first shrunk through `A`. If that fails, a new
	/// block is allocated from `B`, the leading `new_layout.size()` bytes are
	/// copied over and the old block is freed through `A`. A block that `A`
	/// does not claim is shrunk through `B` only.
	///
	/// # Errors
	/// Returns the error from `B` if neither allocator could satisfy the
	/// request; in that case this function has neither freed nor moved the
	/// original block itself.
	unsafe fn shrink(
		&self,
		ptr: NonNull<u8>,
		old_layout: Layout,
		new_layout: Layout,
	) -> Result<NonNull<[u8]>, AllocError> {
		if self.0.claims(ptr.as_ptr(), old_layout) {
			// SAFETY: `A` claims the block; the rest is upheld by the caller.
			let res_a = unsafe { (&self.0).shrink(ptr, old_layout, new_layout) };
			if res_a.is_ok() {
				return res_a;
			}

			let res_b = self.1.allocate(new_layout);
			if let Ok(ptr_b) = res_b {
				// Copy the allocation from `A` to `B`, then release the old block.
				// Only `new_layout.size()` bytes are guaranteed to fit in the new
				// block; copying the whole old block would write past its end.
				// SAFETY: The caller guarantees `new_layout.size() <= old_layout.size()`,
				// so both blocks are valid for that many bytes, and `ptr` is still
				// owned by `A`.
				unsafe {
					ptr.copy_to_nonoverlapping(ptr_b.cast(), new_layout.size());
					(&self.0).deallocate(ptr, old_layout);
				}
			}

			res_b
		} else {
			// Don't fall back to `A`.
			// SAFETY: Upheld by the caller.
			unsafe { self.1.shrink(ptr, old_layout, new_layout) }
		}
	}

	/// Returns `self` unchanged.
	fn by_ref(&self) -> &Self
	where
		Self: Sized,
	{
		self
	}
}