1use core::alloc::{GlobalAlloc, Layout};
2
/// An allocator that can recognize pointers it previously handed out.
///
/// This is what lets [`AllocChain`] route `dealloc`/`realloc` calls back to
/// the allocator that produced the pointer.
///
/// # Safety
///
/// `claims` must return `true` for every pointer/layout pair that this
/// allocator returned (and has not yet freed), and `false` for pointers
/// owned by any other allocator. A wrong answer makes the chain free or
/// resize memory with the wrong allocator, which is undefined behavior.
pub unsafe trait ChainableAlloc {
    /// Returns `true` if `ptr` (allocated with `layout`) belongs to this
    /// allocator and must therefore be freed/resized by it.
    fn claims(&self, ptr: *mut u8, layout: Layout) -> bool;
}
15
16pub struct AllocChain<'a, A, B>(A, &'a B);
32
33impl<'a, A, B> AllocChain<'a, A, B> {
34 pub const fn new(a: A, b: &'a B) -> Self {
36 Self(a, b)
37 }
38
39 pub const fn chain<T>(self, next: &T) -> AllocChain<'_, Self, T>
41 where
42 Self: Sized,
43 {
44 AllocChain::new(self, next)
45 }
46}
47
48unsafe impl<A: GlobalAlloc + ChainableAlloc, B: GlobalAlloc> GlobalAlloc for AllocChain<'_, A, B> {
49 unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
50 let ptr_a = unsafe { self.0.alloc(layout) };
51 if ptr_a.is_null() {
52 unsafe { self.1.alloc(layout) }
53 } else {
54 ptr_a
55 }
56 }
57
58 unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
59 if self.0.claims(ptr, layout) {
60 unsafe { self.0.dealloc(ptr, layout) };
61 } else {
62 unsafe { self.1.dealloc(ptr, layout) };
63 }
64 }
65
66 unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
67 if self.0.claims(ptr, layout) {
68 let ptr_a = unsafe { self.0.realloc(ptr, layout, new_size) };
69 if !ptr_a.is_null() {
70 return ptr_a;
71 }
72
73 let layout_b = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
74 let ptr_b = unsafe { self.1.alloc(layout_b) };
75
76 if !ptr_b.is_null() {
77 unsafe {
79 ptr.copy_to_nonoverlapping(ptr_b, layout.size());
80 self.0.dealloc(ptr, layout);
81 }
82 }
83
84 ptr_b
86 } else {
87 unsafe { self.1.realloc(ptr, layout, new_size) }
88 }
90 }
91}
92
93#[cfg(any(feature = "allocator-api", feature = "allocator-api2"))]
94use {
95 crate::{AllocError, Allocator},
96 core::ptr::NonNull,
97};
98
#[cfg(any(feature = "allocator-api", feature = "allocator-api2"))]
unsafe impl<A: ChainableAlloc, B> Allocator for &AllocChain<'_, A, B>
where
    for<'a> &'a A: Allocator,
    for<'a> &'a B: Allocator,
{
    /// Tries the primary allocator first, falling back to the secondary on error.
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        (&self.0)
            .allocate(layout)
            .or_else(|_| self.1.allocate(layout))
    }

    /// Frees `ptr` with whichever allocator in the chain handed it out,
    /// as reported by [`ChainableAlloc::claims`].
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        if self.0.claims(ptr.as_ptr(), layout) {
            unsafe { (&self.0).deallocate(ptr, layout) };
        } else {
            unsafe { self.1.deallocate(ptr, layout) }
        }
    }

    /// Grows a block via its owning allocator; if the primary owns the block
    /// but cannot grow it, the block is migrated to the fallback.
    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        if self.0.claims(ptr.as_ptr(), old_layout) {
            let res_a = unsafe { (&self.0).grow(ptr, old_layout, new_layout) };
            if res_a.is_ok() {
                return res_a;
            }

            // Primary could not grow; allocate from the fallback and move
            // the contents over.
            let res_b = self.1.allocate(new_layout);
            if let Ok(ptr_b) = res_b {
                // SAFETY: `grow` requires `new_layout.size() >= old_layout.size()`,
                // so the old contents fit in the new block.
                unsafe {
                    ptr.copy_to_nonoverlapping(ptr_b.cast(), old_layout.size());
                    (&self.0).deallocate(ptr, old_layout);
                }
            }

            res_b
        } else {
            unsafe { self.1.grow(ptr, old_layout, new_layout) }
        }
    }

    /// Like [`Allocator::grow`], but zeroes the newly extended region.
    unsafe fn grow_zeroed(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        unsafe {
            let new_ptr = self.grow(ptr, old_layout, new_layout)?;
            // Zero everything past the old contents, up to the actual size
            // of the returned block (which may exceed `new_layout.size()`).
            let count = new_ptr.len() - old_layout.size();

            new_ptr
                .cast::<u8>()
                .add(old_layout.size())
                .write_bytes(0, count);

            Ok(new_ptr)
        }
    }

    /// Shrinks a block via its owning allocator; if the primary owns the
    /// block but cannot shrink it, the block is migrated to the fallback.
    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        if self.0.claims(ptr.as_ptr(), old_layout) {
            let res_a = unsafe { (&self.0).shrink(ptr, old_layout, new_layout) };
            if res_a.is_ok() {
                return res_a;
            }

            // Primary could not shrink; allocate from the fallback and move
            // the contents over.
            let res_b = self.1.allocate(new_layout);
            if let Ok(ptr_b) = res_b {
                // `shrink` requires `new_layout.size() <= old_layout.size()`,
                // so copy only `new_layout.size()` bytes: copying the full old
                // size would write past the end of the smaller new block.
                unsafe {
                    ptr.copy_to_nonoverlapping(ptr_b.cast(), new_layout.size());
                    (&self.0).deallocate(ptr, old_layout);
                }
            }

            res_b
        } else {
            unsafe { self.1.shrink(ptr, old_layout, new_layout) }
        }
    }

    fn by_ref(&self) -> &Self
    where
        Self: Sized,
    {
        self
    }
}
202}