//! A shuffling allocator: a `GlobalAlloc` wrapper that randomizes the
//! addresses of small, word-aligned heap allocations by routing them through
//! per-size-class shuffling arrays.

#![deny(missing_docs)]

mod lazy_atomic_cell;

cfg_if::cfg_if! {
    if #[cfg(unix)] {
        mod pthread_mutex;
        use pthread_mutex::PthreadMutex as Mutex;
    } else if #[cfg(windows)] {
        mod windows_mutex;
        use windows_mutex::WindowsMutex as Mutex;
    } else {
        compile_error!("no mutex implementation for this platform");
    }
}

// Public only so that the `wrap!` macro can name it in its expansion; this is
// not part of the supported API.
#[doc(hidden)]
pub use lazy_atomic_cell::LazyAtomicCell;

use rand::{rngs::StdRng, Rng, SeedableRng};
use std::{
    alloc::{handle_alloc_error, GlobalAlloc, Layout},
    mem::{self, MaybeUninit},
    ptr,
    sync::atomic::{AtomicPtr, Ordering},
};

const SHUFFLING_ARRAY_SIZE: usize = 256;

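// The core trick: each small size class keeps a 256-slot array of live
// allocations. `alloc` swaps a freshly allocated block into a random slot and
// returns the block that used to live there; `dealloc` swaps the freed block
// in and frees the block it displaces. The pointer the caller sees is thus
// one that was allocated at some unpredictable earlier moment, which
// randomizes heap locality.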
struct ShufflingArray<A>
where
    A: 'static + GlobalAlloc,
{
    elems: [AtomicPtr<u8>; SHUFFLING_ARRAY_SIZE],
    size_class: usize,
    allocator: &'static A,
}

impl<A> Drop for ShufflingArray<A>
where
    A: 'static + GlobalAlloc,
{
    fn drop(&mut self) {
        let layout =
            unsafe { Layout::from_size_align_unchecked(self.size_class, mem::align_of::<usize>()) };
        for el in &self.elems {
            let p = el.swap(ptr::null_mut(), Ordering::SeqCst);
            if !p.is_null() {
                unsafe {
                    self.allocator.dealloc(p, layout);
                }
            }
        }
    }
}

impl<A> ShufflingArray<A>
where
    A: 'static + GlobalAlloc,
{
    fn new(size_class: usize, allocator: &'static A) -> Self {
        let elems = unsafe {
            let mut elems = MaybeUninit::<[AtomicPtr<u8>; SHUFFLING_ARRAY_SIZE]>::uninit();
            let elems_ptr: *mut [AtomicPtr<u8>; SHUFFLING_ARRAY_SIZE] = elems.as_mut_ptr();
            let elems_ptr: *mut AtomicPtr<u8> = elems_ptr.cast();
            let layout = Layout::from_size_align_unchecked(size_class, mem::align_of::<usize>());
            // Fill every slot with a live allocation up front, so that
            // `alloc` can always swap a pointer out of the array.
            for i in 0..SHUFFLING_ARRAY_SIZE {
                let p = allocator.alloc(layout);
                if p.is_null() {
                    handle_alloc_error(layout);
                }
                ptr::write(elems_ptr.add(i), AtomicPtr::new(p));
            }
            elems.assume_init()
        };
        ShufflingArray {
            elems,
            size_class,
            allocator,
        }
    }

    /// The layout of the blocks held in this array: `size_class` bytes,
    /// word-aligned.
    fn elem_layout(&self) -> Layout {
        unsafe {
            debug_assert!(
                Layout::from_size_align(self.size_class, mem::align_of::<usize>()).is_ok()
            );
            Layout::from_size_align_unchecked(self.size_class, mem::align_of::<usize>())
        }
    }
}

struct SizeClasses<A>([LazyAtomicCell<A, ShufflingArray<A>>; NUM_SIZE_CLASSES])
where
    A: 'static + GlobalAlloc;

struct SizeClassInfo {
    index: usize,
    size_class: usize,
}

// Size classes start at one machine word and grow in word-sized strides, with
// the stride doubling after every four classes. The lookup is unrolled by
// hand (hence `#[rustfmt::skip]`) so that it compiles down to a short chain
// of compare-and-return branches.
#[rustfmt::skip]
#[inline]
fn size_class_info(size: usize) -> Option<SizeClassInfo> {
    let mut size_class = mem::size_of::<usize>();
    let mut stride = mem::size_of::<usize>();

    if size <= size_class {
        return Some(SizeClassInfo { index: 0, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 1, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 2, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 3, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 4, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 5, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 6, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 7, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 8, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 9, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 10, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 11, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 12, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 13, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 14, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 15, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 16, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 17, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 18, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 19, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 20, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 21, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 22, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 23, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 24, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 25, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 26, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 27, size_class });
    }
    size_class += stride;

    stride *= 2;

    if size <= size_class {
        return Some(SizeClassInfo { index: 28, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 29, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 30, size_class });
    }
    size_class += stride;
    if size <= size_class {
        return Some(SizeClassInfo { index: 31, size_class });
    }

    None
}
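
// A sanity check of the mapping above. The concrete values assume a 64-bit
// target (`size_of::<usize>() == 8`); they are derived by hand from the
// stride-doubling scheme and are illustrative, not normative.
#[cfg(all(test, target_pointer_width = "64"))]
mod size_class_tests {
    use super::*;

    #[test]
    fn maps_sizes_to_expected_classes() {
        let info = size_class_info(1).unwrap();
        assert_eq!((info.index, info.size_class), (0, 8));

        let info = size_class_info(8).unwrap();
        assert_eq!((info.index, info.size_class), (0, 8));

        // One byte past a class boundary bumps to the next class.
        let info = size_class_info(9).unwrap();
        assert_eq!((info.index, info.size_class), (1, 16));

        // After the first four word-strided classes (8, 16, 24, 32) the
        // stride doubles, giving 40, 56, 72, 88, and so on.
        let info = size_class_info(33).unwrap();
        assert_eq!((info.index, info.size_class), (4, 40));

        // The largest class on 64-bit targets works out to 7144 bytes;
        // anything bigger is not shuffled.
        let info = size_class_info(7144).unwrap();
        assert_eq!((info.index, info.size_class), (31, 7144));
        assert!(size_class_info(7145).is_none());
    }
}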

const NUM_SIZE_CLASSES: usize = 32;

/// A global allocator adapter that wraps an inner `GlobalAlloc` and shuffles
/// the addresses it hands out for small, word-aligned allocations.
///
/// Construct values of this type with the [`wrap!`] macro, so that they can
/// initialize a `#[global_allocator]` static.
pub struct ShufflingAllocator<A>
where
    A: 'static + GlobalAlloc,
{
    // These fields are public only so that `wrap!` can build this struct in a
    // `static` initializer; they are not meant to be touched directly.
    #[doc(hidden)]
    pub inner: &'static A,
    #[doc(hidden)]
    pub state: LazyAtomicCell<A, State<A>>,
}

#[doc(hidden)]
pub struct State<A>
where
    A: 'static + GlobalAlloc,
{
    rng: Mutex<A, StdRng>,
    size_classes: LazyAtomicCell<A, SizeClasses<A>>,
}

/// Wrap an existing global allocator in a [`ShufflingAllocator`].
///
/// The expansion is a plain struct literal that is valid in `const` contexts,
/// so the result can initialize a `#[global_allocator]` static, e.g.
/// `static ALLOC: ShufflingAllocator<System> = wrap!(&System);`.
#[macro_export]
macro_rules! wrap {
    ($inner:expr) => {
        $crate::ShufflingAllocator {
            inner: $inner,
            state: $crate::LazyAtomicCell {
                ptr: ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut()),
                allocator: $inner,
            },
        }
    };
}
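
// A minimal in-crate smoke test of `wrap!` (illustrative; the layout values
// are arbitrary). Real users would install the result as the global
// allocator, e.g.
//
//     #[global_allocator]
//     static ALLOC: ShufflingAllocator<System> = wrap!(&System);
//
// but here we call the `GlobalAlloc` methods directly, so the test does not
// replace the test harness's own allocator.
#[cfg(test)]
mod wrap_tests {
    use super::*;
    use std::alloc::System;

    #[test]
    fn wrap_roundtrip() {
        static ALLOC: ShufflingAllocator<System> = wrap!(&System);
        let layout = Layout::from_size_align(24, mem::align_of::<usize>()).unwrap();
        unsafe {
            let p = ALLOC.alloc(layout);
            assert!(!p.is_null());
            // Writing through the pointer checks the block is usable memory.
            ptr::write_bytes(p, 0xAB, layout.size());
            ALLOC.dealloc(p, layout);
        }
    }
}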

impl<A> ShufflingAllocator<A>
where
    A: 'static + GlobalAlloc,
{
    // Lazily create the shared state (the RNG and the size-class table) on
    // first use: statics must be const-initialized, so there is no place to
    // run a constructor before the allocator is first called.
    #[inline]
    fn state(&self) -> &State<A> {
        self.state.get_or_create(|| State {
            rng: Mutex::new(&self.inner, StdRng::from_entropy()),
            size_classes: LazyAtomicCell::new(self.inner),
        })
    }

    #[inline]
    fn random_index(&self) -> usize {
        let mut rng = self.state().rng.lock();
        rng.gen_range(0..SHUFFLING_ARRAY_SIZE)
    }

    #[inline]
    fn size_classes(&self) -> &SizeClasses<A> {
        self.state().size_classes.get_or_create(|| {
            let mut classes =
                MaybeUninit::<[LazyAtomicCell<A, ShufflingArray<A>>; NUM_SIZE_CLASSES]>::uninit();
            unsafe {
                for i in 0..NUM_SIZE_CLASSES {
                    ptr::write(
                        classes
                            .as_mut_ptr()
                            .cast::<LazyAtomicCell<A, ShufflingArray<A>>>()
                            .add(i),
                        LazyAtomicCell::new(self.inner),
                    );
                }
                SizeClasses(classes.assume_init())
            }
        })
    }

    #[inline]
    fn shuffling_array(&self, size: usize) -> Option<&ShufflingArray<A>> {
        let SizeClassInfo { index, size_class } = size_class_info(size)?;
        let size_classes = self.size_classes();
        Some(size_classes.0[index].get_or_create(|| ShufflingArray::new(size_class, self.inner)))
    }
}

unsafe impl<A> GlobalAlloc for ShufflingAllocator<A>
where
    A: GlobalAlloc,
{
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // Over-aligned allocations don't fit the word-aligned shuffling
        // arrays; defer to the inner allocator.
        if layout.align() > mem::align_of::<usize>() {
            return self.inner.alloc(layout);
        }

        match self.shuffling_array(layout.size()) {
            // Too big for any size class: defer to the inner allocator.
            None => self.inner.alloc(layout),

            // Allocate a fresh block, swap it into a random slot, and hand
            // out whatever block that slot previously held.
            Some(array) => {
                let replacement_ptr = self.inner.alloc(array.elem_layout());
                if replacement_ptr.is_null() {
                    return ptr::null_mut();
                }

                let index = self.random_index();
                array.elems[index].swap(replacement_ptr, Ordering::SeqCst)
            }
        }
    }

    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        if ptr.is_null() {
            return;
        }

        if layout.align() > mem::align_of::<usize>() {
            self.inner.dealloc(ptr, layout);
            return;
        }

        match self.shuffling_array(layout.size()) {
            // Too big for any size class: it was allocated directly by the
            // inner allocator, so free it there as well.
            None => self.inner.dealloc(ptr, layout),

            // Swap the freed block into a random slot and free the block it
            // displaces instead.
            Some(array) => {
                let index = self.random_index();
                let old_ptr = array.elems[index].swap(ptr, Ordering::SeqCst);
                self.inner.dealloc(old_ptr, array.elem_layout());
            }
        }
    }
}
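
// A smoke test covering the three `alloc`/`dealloc` paths above: the shuffled
// path (small, word-aligned layouts), the too-large passthrough, and the
// over-aligned passthrough. The layout values are arbitrary examples.
#[cfg(test)]
mod path_tests {
    use super::*;
    use std::alloc::System;

    static ALLOC: ShufflingAllocator<System> = wrap!(&System);

    unsafe fn roundtrip(layout: Layout) {
        let p = ALLOC.alloc(layout);
        assert!(!p.is_null());
        ptr::write_bytes(p, 0xCD, layout.size());
        ALLOC.dealloc(p, layout);
    }

    #[test]
    fn all_three_paths() {
        unsafe {
            // Shuffled: small size, word alignment.
            roundtrip(Layout::from_size_align(16, mem::align_of::<usize>()).unwrap());
            // Passthrough: larger than the largest size class.
            roundtrip(Layout::from_size_align(1 << 20, mem::align_of::<usize>()).unwrap());
            // Passthrough: alignment stricter than a word.
            roundtrip(Layout::from_size_align(64, 64).unwrap());
        }
    }
}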