use core::mem::{align_of, size_of};

use crate::region::Region;
use crate::slot::{SlotMeta, SlotState};
use crate::sync::{AtomicU32, AtomicU64, Ordering, spin_loop};

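/// Sentinel "next" index terminating the intrusive free list.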
pub const FREE_LIST_END: u32 = u32::MAX;

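/// Slab header shared between processes. `repr(C, align(64))` pins the field
/// layout and pads the header to a single 64-byte cache line.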
#[repr(C, align(64))]
pub struct TreiberSlabHeader {
    pub slot_size: u32,
    pub slot_count: u32,
    pub max_frame_size: u32,
    _pad: u32,

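    /// Packed free-list head: the low 32 bits hold the index of the first
    /// free slot (`FREE_LIST_END` when the list is empty), the high 32 bits
    /// an ABA tag bumped on every successful update.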
    pub free_head: AtomicU64,

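    /// Reserved for availability signalling; initialized to zero and not
    /// otherwise touched in this module.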
    pub slot_available: AtomicU32,

    _pad2: [u8; 36],
}

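// Compile-time check: the header must occupy exactly one cache line.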
#[cfg(not(loom))]
const _: () = assert!(core::mem::size_of::<TreiberSlabHeader>() == 64);

impl TreiberSlabHeader {
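    /// Resets every header field; the free list itself is linked afterwards
    /// by `TreiberSlabRaw::init_free_list`.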
    pub fn init(&mut self, slot_size: u32, slot_count: u32) {
        self.slot_size = slot_size;
        self.slot_count = slot_count;
        self.max_frame_size = slot_size;
        self._pad = 0;
        self.free_head = AtomicU64::new(pack_free_head(FREE_LIST_END, 0));
        self.slot_available = AtomicU32::new(0);
        self._pad2 = [0; 36];
    }
}

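/// Handle to an allocated slot: an index plus the generation observed at
/// allocation time, so stale handles can be detected and rejected.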
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotHandle {
    pub index: u32,
    pub generation: u32,
}

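/// Outcome of a non-blocking allocation attempt. `WouldBlock` means the free
/// list was empty at the time of the pop.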
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocResult {
    Ok(SlotHandle),
    WouldBlock,
}

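/// Why a state transition on a slot handle was refused.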
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlotError {
    InvalidIndex,
    GenerationMismatch {
        expected: u32,
        actual: u32,
    },
    InvalidState {
        expected: SlotState,
        actual: SlotState,
    },
}

pub type FreeError = SlotError;

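/// Owning slab handle: couples the raw slab view with the `Region` that
/// backs it, keeping the mapping alive for the slab's lifetime.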
pub struct TreiberSlab {
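    /// Held only to keep the mapping alive; never read directly.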
    #[allow(dead_code)]
    region: Region,
    inner: TreiberSlabRaw,
}

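// SAFETY: all cross-thread mutation happens through the atomics stored in
// the shared region; the raw pointers themselves are never reseated.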
unsafe impl Send for TreiberSlab {}
unsafe impl Sync for TreiberSlab {}

impl TreiberSlab {
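    /// Lays out and initializes a fresh slab (header, slot metadata, slot
    /// data) inside `region`, then links every slot onto the free list.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to `region` for the duration of
    /// the call; no other thread or process may be attached yet.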
    pub unsafe fn init(
        region: Region,
        header_offset: usize,
        slot_count: u32,
        slot_size: u32,
    ) -> Self {
        assert!(slot_count > 0, "slot_count must be > 0");
        assert!(
            slot_size >= size_of::<u32>() as u32,
            "slot_size must be >= 4"
        );
        assert!(
            header_offset.is_multiple_of(64),
            "header_offset must be 64-byte aligned"
        );

        let meta_offset = align_up(
            header_offset + size_of::<TreiberSlabHeader>(),
            align_of::<SlotMeta>(),
        );
        let data_offset = align_up(
            meta_offset + (slot_count as usize * size_of::<SlotMeta>()),
            align_of::<u32>(),
        );
        let required = data_offset + (slot_count as usize * slot_size as usize);
        assert!(required <= region.len(), "region too small for slab");

        let header_ptr = region.offset(header_offset) as *mut TreiberSlabHeader;
        let slot_meta_ptr = region.offset(meta_offset) as *mut SlotMeta;
        let slot_data_ptr = region.offset(data_offset);

        unsafe { (*header_ptr).init(slot_size, slot_count) };

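        // Reset per-slot metadata before any slot becomes reachable.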
        for i in 0..slot_count {
            let meta = unsafe { &mut *slot_meta_ptr.add(i as usize) };
            meta.init();
        }

        let inner = unsafe { TreiberSlabRaw::from_raw(header_ptr, slot_meta_ptr, slot_data_ptr) };

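        // Link all slots into the free list and publish the head.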
        unsafe { inner.init_free_list() };

        Self { region, inner }
    }

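    /// Attaches to a slab previously initialized in `region`, re-deriving
    /// the layout from the stored header after basic validation.
    ///
    /// # Safety
    ///
    /// `region` must contain a header written by `init`; its fields are
    /// trusted to describe the slab layout.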
    pub unsafe fn attach(region: Region, header_offset: usize) -> Result<Self, &'static str> {
        assert!(
            header_offset.is_multiple_of(64),
            "header_offset must be 64-byte aligned"
        );

        let header_ptr = region.offset(header_offset) as *mut TreiberSlabHeader;
        let header = unsafe { &*header_ptr };

        if header.slot_count == 0 {
            return Err("slot_count must be > 0");
        }
        if header.slot_size < size_of::<u32>() as u32 {
            return Err("slot_size must be >= 4");
        }

        let meta_offset = align_up(
            header_offset + size_of::<TreiberSlabHeader>(),
            align_of::<SlotMeta>(),
        );
        let data_offset = align_up(
            meta_offset + (header.slot_count as usize * size_of::<SlotMeta>()),
            align_of::<u32>(),
        );
        let required = data_offset + (header.slot_count as usize * header.slot_size as usize);
        if required > region.len() {
            return Err("region too small for slab");
        }

        let slot_meta_ptr = region.offset(meta_offset) as *mut SlotMeta;
        let slot_data_ptr = region.offset(data_offset);

        let inner = unsafe { TreiberSlabRaw::from_raw(header_ptr, slot_meta_ptr, slot_data_ptr) };

        Ok(Self { region, inner })
    }

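    /// Returns the underlying raw slab view.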
    #[inline]
    pub fn inner(&self) -> &TreiberSlabRaw {
        &self.inner
    }

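    /// Pops a free slot, or reports `WouldBlock` when none are free.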
    pub fn try_alloc(&self) -> AllocResult {
        self.inner.try_alloc()
    }

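    /// Transitions a slot from `Allocated` to `InFlight`.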
    pub fn mark_in_flight(&self, handle: SlotHandle) -> Result<(), SlotError> {
        self.inner.mark_in_flight(handle)
    }

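    /// Returns an `InFlight` slot to the free list.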
    pub fn free(&self, handle: SlotHandle) -> Result<(), SlotError> {
        self.inner.free(handle)
    }

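    /// Returns an `Allocated` slot that was never marked in-flight.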
    pub fn free_allocated(&self, handle: SlotHandle) -> Result<(), SlotError> {
        self.inner.free_allocated(handle)
    }

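    /// Raw pointer to the slot's data area.
    ///
    /// # Safety
    ///
    /// `handle.index` must be in bounds; no generation or state check is
    /// performed.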
    pub unsafe fn slot_data_ptr(&self, handle: SlotHandle) -> *mut u8 {
        unsafe { self.inner.slot_data_ptr(handle) }
    }

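    /// Size in bytes of each slot's data area.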
    #[inline]
    pub fn slot_size(&self) -> u32 {
        self.inner.slot_size()
    }

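    /// Total number of slots in the slab.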
    #[inline]
    pub fn slot_count(&self) -> u32 {
        self.inner.slot_count()
    }

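    /// Best-effort count of free slots; see `TreiberSlabRaw::free_count_approx`.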
    pub fn free_count_approx(&self) -> u32 {
        self.inner.free_count_approx()
    }
}

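/// Packs a free-list head: slot index in the low half, ABA tag in the high.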
#[inline]
fn pack_free_head(index: u32, tag: u32) -> u64 {
    ((tag as u64) << 32) | (index as u64)
}

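/// Inverse of `pack_free_head`: returns `(index, tag)`.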
#[inline]
fn unpack_free_head(packed: u64) -> (u32, u32) {
    let index = packed as u32;
    let tag = (packed >> 32) as u32;
    (index, tag)
}

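/// Rounds `value` up to the next multiple of `align` (a power of two).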
#[inline]
const fn align_up(value: usize, align: usize) -> usize {
    (value + (align - 1)) & !(align - 1)
}
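// Quick sanity checks for the helpers above; a minimal sketch assuming the
// standard `cargo test` harness in a non-loom build.
#[cfg(all(test, not(loom)))]
mod helper_tests {
    use super::*;

    #[test]
    fn free_head_packing_round_trips() {
        let packed = pack_free_head(7, 42);
        assert_eq!(unpack_free_head(packed), (7, 42));

        // The sentinel index survives packing unchanged.
        let (index, _) = unpack_free_head(pack_free_head(FREE_LIST_END, 0));
        assert_eq!(index, FREE_LIST_END);
    }

    #[test]
    fn align_up_rounds_to_power_of_two_boundaries() {
        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }
}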
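/// Borrowed view of a slab: raw pointers into memory owned elsewhere.
///
/// Unlike `TreiberSlab`, this type does not keep the backing region alive;
/// callers must ensure the mapping outlives every use of the view.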
pub struct TreiberSlabRaw {
    header: *mut TreiberSlabHeader,
    slot_meta: *mut SlotMeta,
    slot_data: *mut u8,
}

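// SAFETY: concurrent access is mediated by the atomics behind these
// pointers; keeping the backing memory alive is the caller's obligation.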
unsafe impl Send for TreiberSlabRaw {}
unsafe impl Sync for TreiberSlabRaw {}

impl TreiberSlabRaw {
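    /// Builds a view from the slab's three component pointers.
    ///
    /// # Safety
    ///
    /// The pointers must describe one live slab layout (header, metadata
    /// array, data area) produced by `TreiberSlab::init` or an equivalent
    /// writer, and must stay valid for the lifetime of the returned view.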
    #[inline]
    pub unsafe fn from_raw(
        header: *mut TreiberSlabHeader,
        slot_meta: *mut SlotMeta,
        slot_data: *mut u8,
    ) -> Self {
        Self {
            header,
            slot_meta,
            slot_data,
        }
    }

    #[inline]
    fn header(&self) -> &TreiberSlabHeader {
        unsafe { &*self.header }
    }

    #[inline]
    unsafe fn meta(&self, index: u32) -> &SlotMeta {
        unsafe { &*self.slot_meta.add(index as usize) }
    }

    #[inline]
    unsafe fn data_ptr(&self, index: u32) -> *mut u8 {
        let slot_size = self.header().slot_size as usize;
        unsafe { self.slot_data.add(index as usize * slot_size) }
    }

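    // While a slot sits on the free list, the first four bytes of its data
    // area store the index of the next free slot (an intrusive list).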
    #[inline]
    unsafe fn read_next_free(&self, index: u32) -> u32 {
        let ptr = unsafe { self.data_ptr(index) as *const u32 };
        unsafe { core::ptr::read_volatile(ptr) }
    }

    #[inline]
    unsafe fn write_next_free(&self, index: u32, next: u32) {
        let ptr = unsafe { self.data_ptr(index) as *mut u32 };
        unsafe { core::ptr::write_volatile(ptr, next) };
    }

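    /// Links every slot into the free list (`0 -> 1 -> ... -> n-1`) and
    /// publishes slot 0 as the head.
    ///
    /// # Safety
    ///
    /// Must complete before any concurrent use of the slab: slot data is
    /// written non-atomically here.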
    pub unsafe fn init_free_list(&self) {
        let slot_count = self.header().slot_count;
        if slot_count == 0 {
            return;
        }

        for i in 0..slot_count - 1 {
            unsafe { self.write_next_free(i, i + 1) };
        }
        unsafe { self.write_next_free(slot_count - 1, FREE_LIST_END) };

        self.header()
            .free_head
            .store(pack_free_head(0, 0), Ordering::Release);
    }

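    /// Pops the free-list head (Treiber pop) and claims the slot with a
    /// `Free -> Allocated` state CAS. Returns `WouldBlock` when empty.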
    pub fn try_alloc(&self) -> AllocResult {
        let header = self.header();

        loop {
            let old_head = header.free_head.load(Ordering::Acquire);
            let (index, tag) = unpack_free_head(old_head);

            if index == FREE_LIST_END {
                return AllocResult::WouldBlock;
            }

            let next = unsafe { self.read_next_free(index) };
            let new_head = pack_free_head(next, tag.wrapping_add(1));

            match header.free_head.compare_exchange_weak(
                old_head,
                new_head,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    let meta = unsafe { self.meta(index) };
                    let result = meta.state.compare_exchange(
                        SlotState::Free as u32,
                        SlotState::Allocated as u32,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    );

                    if result.is_err() {
                        let current_state = meta.state.load(Ordering::Acquire);
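                        // The popped slot was not Free. If it has since
                        // become Free again, put it back rather than leak it.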
                        if current_state == SlotState::Free as u32 {
                            self.push_to_free_list(index);
                        }
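                        // Debug builds flag the broken invariant: anything
                        // popped from the free list should have been Free.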
                        debug_assert_eq!(
                            current_state,
                            SlotState::Free as u32,
                            "slot popped from free list had unexpected state"
                        );
                        spin_loop();
                        continue;
                    }

                    let generation =
                        meta.generation.fetch_add(1, Ordering::AcqRel).wrapping_add(1);
                    return AllocResult::Ok(SlotHandle { index, generation });
                }
                Err(_) => {
                    spin_loop();
                    continue;
                }
            }
        }
    }

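    /// Transitions `Allocated -> InFlight` after validating the handle's
    /// index and generation.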
    pub fn mark_in_flight(&self, handle: SlotHandle) -> Result<(), SlotError> {
        if handle.index >= self.header().slot_count {
            return Err(SlotError::InvalidIndex);
        }

        let meta = unsafe { self.meta(handle.index) };
        let actual = meta.generation.load(Ordering::Acquire);
        if actual != handle.generation {
            return Err(SlotError::GenerationMismatch {
                expected: handle.generation,
                actual,
            });
        }

        let result = meta.state.compare_exchange(
            SlotState::Allocated as u32,
            SlotState::InFlight as u32,
            Ordering::AcqRel,
            Ordering::Acquire,
        );

        result
            .map(|_| ())
            .map_err(|actual| SlotError::InvalidState {
                expected: SlotState::Allocated,
                actual: SlotState::from_u32(actual).unwrap_or(SlotState::Free),
            })
    }

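    /// Transitions `InFlight -> Free` and pushes the slot back onto the free
    /// list.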
    pub fn free(&self, handle: SlotHandle) -> Result<(), SlotError> {
        if handle.index >= self.header().slot_count {
            return Err(SlotError::InvalidIndex);
        }

        let meta = unsafe { self.meta(handle.index) };
        let actual = meta.generation.load(Ordering::Acquire);
        if actual != handle.generation {
            return Err(SlotError::GenerationMismatch {
                expected: handle.generation,
                actual,
            });
        }

        let result = meta.state.compare_exchange(
            SlotState::InFlight as u32,
            SlotState::Free as u32,
            Ordering::AcqRel,
            Ordering::Acquire,
        );

        if result.is_ok() {
            self.push_to_free_list(handle.index);
            Ok(())
        } else {
            Err(SlotError::InvalidState {
                expected: SlotState::InFlight,
                actual: SlotState::from_u32(result.err().unwrap()).unwrap_or(SlotState::Free),
            })
        }
    }

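    /// Transitions `Allocated -> Free` directly, for slots that never became
    /// in-flight.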
    pub fn free_allocated(&self, handle: SlotHandle) -> Result<(), SlotError> {
        if handle.index >= self.header().slot_count {
            return Err(SlotError::InvalidIndex);
        }

        let meta = unsafe { self.meta(handle.index) };
        let actual = meta.generation.load(Ordering::Acquire);
        if actual != handle.generation {
            return Err(SlotError::GenerationMismatch {
                expected: handle.generation,
                actual,
            });
        }

        let result = meta.state.compare_exchange(
            SlotState::Allocated as u32,
            SlotState::Free as u32,
            Ordering::AcqRel,
            Ordering::Acquire,
        );

        if result.is_ok() {
            self.push_to_free_list(handle.index);
            Ok(())
        } else {
            Err(SlotError::InvalidState {
                expected: SlotState::Allocated,
                actual: SlotState::from_u32(result.err().unwrap()).unwrap_or(SlotState::Free),
            })
        }
    }

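    /// Raw pointer to the data area of `handle`'s slot.
    ///
    /// # Safety
    ///
    /// `handle.index` must be in bounds; neither generation nor state is
    /// checked.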
    #[inline]
    pub unsafe fn slot_data_ptr(&self, handle: SlotHandle) -> *mut u8 {
        unsafe { self.data_ptr(handle.index) }
    }

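    /// Size in bytes of each slot's data area.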
    #[inline]
    pub fn slot_size(&self) -> u32 {
        self.header().slot_size
    }

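    /// Total number of slots in the slab.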
    #[inline]
    pub fn slot_count(&self) -> u32 {
        self.header().slot_count
    }

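    /// Walks the free list without synchronization and counts entries. The
    /// result is only a snapshot: concurrent pushes and pops can make it
    /// stale immediately, and traversal is capped at `slot_count` hops.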
    pub fn free_count_approx(&self) -> u32 {
        let slot_count = self.header().slot_count;
        let mut free_list_len = 0u32;
        let mut current = {
            let (index, _tag) = unpack_free_head(self.header().free_head.load(Ordering::Acquire));
            index
        };

        while current != FREE_LIST_END && free_list_len < slot_count {
            free_list_len += 1;
            if current < slot_count {
                current = unsafe { self.read_next_free(current) };
            } else {
                break;
            }
        }

        free_list_len
    }

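    // Treiber push: point the new node at the current head, then CAS the
    // head (with a fresh ABA tag) to the new node.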
    fn push_to_free_list(&self, index: u32) {
        let header = self.header();

        loop {
            let old_head = header.free_head.load(Ordering::Acquire);
            let (old_index, tag) = unpack_free_head(old_head);

            unsafe { self.write_next_free(index, old_index) };

            let new_head = pack_free_head(index, tag.wrapping_add(1));

            if header
                .free_head
                .compare_exchange_weak(old_head, new_head, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                return;
            }
        }
    }
}
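// An end-to-end lifecycle sketch over plain heap memory, driving
// `TreiberSlabRaw` directly so no `Region` is needed. Assumptions worth
// flagging: it relies on `SlotMeta` being plain atomic/integer data (so
// zero-initialization before `init()` is sound) and on the standard
// `cargo test` harness in a non-loom build; illustrative only.
#[cfg(all(test, not(loom)))]
mod lifecycle_tests {
    use super::*;

    #[test]
    fn slot_lifecycle_round_trip() {
        const SLOTS: u32 = 4;
        const SLOT_SIZE: u32 = 16;

        // Header and metadata are zeroed, then initialized through the same
        // entry points `TreiberSlab::init` uses.
        let mut header: TreiberSlabHeader = unsafe { core::mem::zeroed() };
        header.init(SLOT_SIZE, SLOTS);

        let mut metas: [SlotMeta; SLOTS as usize] = unsafe { core::mem::zeroed() };
        for meta in metas.iter_mut() {
            meta.init();
        }

        // u32 storage keeps the intrusive next-free links 4-byte aligned.
        let mut data = [0u32; (SLOTS * SLOT_SIZE) as usize / 4];

        let slab = unsafe {
            TreiberSlabRaw::from_raw(
                &mut header,
                metas.as_mut_ptr(),
                data.as_mut_ptr() as *mut u8,
            )
        };
        unsafe { slab.init_free_list() };
        assert_eq!(slab.free_count_approx(), SLOTS);

        let handle = match slab.try_alloc() {
            AllocResult::Ok(handle) => handle,
            AllocResult::WouldBlock => panic!("fresh slab should have free slots"),
        };
        slab.mark_in_flight(handle).unwrap();
        slab.free(handle).unwrap();

        // A second free must fail: the slot is no longer in-flight.
        assert!(matches!(
            slab.free(handle),
            Err(SlotError::InvalidState { .. })
        ));
    }
}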