use core::mem::MaybeUninit;
use crate::traits::{Initialize, InitializeExt as _, InitializeVectored, TrustedDeref};
use crate::wrappers::{AsUninit, AssertInit, AssertInitVectors, SingleVector};
/// Tracks how much of a single possibly-uninitialized buffer has been initialized.
///
/// The initialized items always form a prefix of the buffer: items
/// `0..items_initialized` are initialized, everything after that is possibly
/// uninitialized.
#[derive(Debug)]
pub struct BufferInitializer<T> {
/// The wrapped buffer.
pub(crate) inner: T,
/// Number of items, counted from the start of the buffer, that are initialized.
/// Kept `<=` the buffer's capacity (checked by `debug_assert_validity`).
pub(crate) items_initialized: usize,
}
impl<T> BufferInitializer<T> {
    /// Wraps `inner` without marking any item as initialized.
    #[inline]
    pub const fn uninit(inner: T) -> Self {
        let items_initialized = 0;
        Self {
            inner,
            items_initialized,
        }
    }

    /// Consumes the initializer and returns the wrapped buffer, discarding the
    /// initialization progress.
    #[inline]
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Number of items, counted from the start of the buffer, that are initialized.
    #[inline]
    pub const fn items_initialized(&self) -> usize {
        self.items_initialized
    }
}
impl<T, Item> BufferInitializer<AsUninit<T>>
where
    T: core::ops::Deref<Target = [Item]> + core::ops::DerefMut + TrustedDeref,
{
    /// Wraps an already-initialized container and marks every item as initialized.
    ///
    /// `T` dereferences to `[Item]` rather than `[MaybeUninit<Item>]`, so all of its
    /// items are valid from the start and the returned initializer is complete.
    pub fn new(init: T) -> Self {
        let mut initializer = Self::uninit(AsUninit(init));
        // SAFETY: `init` derefs to a slice of fully valid `Item`s, so declaring the whole
        // buffer initialized is sound (presumably `TrustedDeref` guarantees the deref
        // target stays the same between calls — TODO confirm in `crate::traits`).
        unsafe { initializer.advance_to_end() };
        initializer
    }
}
impl<T> BufferInitializer<T>
where
    T: Initialize,
{
    /// Debug-only check of the core invariant: the initialized prefix never extends past
    /// the end of the buffer.
    pub(crate) fn debug_assert_validity(&self) {
        debug_assert!(self.items_initialized <= self.capacity());
    }

    /// Marks the next `count` items after the initialized prefix as initialized.
    ///
    /// # Safety
    ///
    /// The caller must have initialized the first `count` items of [`Self::uninit_part`],
    /// and `count` must not exceed [`Self::remaining`].
    #[inline]
    pub unsafe fn advance(&mut self, count: usize) {
        // Overshooting would make `init_part` build a slice past the end of the buffer;
        // catch such contract violations in debug builds.
        debug_assert!(
            count <= self.remaining(),
            "cannot advance past the end of the buffer"
        );
        self.items_initialized += count;
    }

    /// Marks the whole buffer as initialized.
    ///
    /// # Safety
    ///
    /// Every item of the buffer must actually be initialized.
    #[inline]
    pub unsafe fn advance_to_end(&mut self) {
        self.items_initialized = self.all_uninit().len();
    }

    /// Converts into the wrapped buffer asserted as fully initialized, without checking.
    ///
    /// # Safety
    ///
    /// The entire buffer must be initialized. [`Self::try_into_init`] is the checked
    /// alternative.
    #[inline]
    pub unsafe fn assume_init(self) -> AssertInit<T> {
        self.inner.assume_init()
    }

    /// Splits into the wrapped buffer and the number of initialized items.
    #[inline]
    pub fn into_raw_parts(self) -> (T, usize) {
        let Self {
            inner,
            items_initialized,
        } = self;
        (inner, items_initialized)
    }

    /// Returns the entire buffer as possibly-uninitialized items.
    #[inline]
    pub fn all_uninit(&self) -> &[MaybeUninit<T::Item>] {
        self.inner.as_maybe_uninit_slice()
    }

    /// Mutable view of the entire buffer as possibly-uninitialized items.
    ///
    /// # Safety
    ///
    /// The view includes the initialized prefix. The caller must not de-initialize any
    /// item in that prefix (e.g. by writing `MaybeUninit::uninit()`), because the other
    /// accessors read the prefix as initialized.
    #[inline]
    pub unsafe fn all_uninit_mut(&mut self) -> &mut [MaybeUninit<T::Item>] {
        self.inner.as_maybe_uninit_slice_mut()
    }

    /// Total number of items in the buffer, initialized or not.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.all_uninit().len()
    }

    /// Number of items that are not yet initialized.
    #[inline]
    pub fn remaining(&self) -> usize {
        debug_assert!(self.capacity() >= self.items_initialized);
        self.capacity().wrapping_sub(self.items_initialized)
    }

    /// Returns `true` when every item of the buffer is initialized.
    #[inline]
    pub fn is_completely_init(&self) -> bool {
        self.items_initialized() == self.capacity()
    }

    /// Returns `true` when no item has been marked as initialized yet.
    #[inline]
    pub fn is_completely_uninit(&self) -> bool {
        self.items_initialized() == 0
    }

    /// Returns the not-yet-initialized suffix of the buffer.
    #[inline]
    pub fn uninit_part(&self) -> &[MaybeUninit<T::Item>] {
        let all = self.all_uninit();
        self.debug_assert_validity();
        // SAFETY: `items_initialized <= all.len()` (the type's invariant), so the offset
        // pointer stays within or one past the end of `all`, and `len` counts exactly the
        // elements that follow it.
        unsafe {
            let ptr = all.as_ptr().add(self.items_initialized);
            let len = all.len().wrapping_sub(self.items_initialized);
            core::slice::from_raw_parts(ptr, len)
        }
    }

    /// Returns the initialized prefix of the buffer.
    #[inline]
    pub fn init_part(&self) -> &[T::Item] {
        self.debug_assert_validity();
        // SAFETY: the first `items_initialized` items are initialized and lie within the
        // buffer, and `MaybeUninit<X>` has the same layout as `X`.
        unsafe {
            let ptr = self.all_uninit().as_ptr();
            let len = self.items_initialized;
            core::slice::from_raw_parts(ptr as *const T::Item, len)
        }
    }

    /// Mutable view of the not-yet-initialized suffix of the buffer.
    #[inline]
    pub fn uninit_part_mut(&mut self) -> &mut [MaybeUninit<T::Item>] {
        // SAFETY: only the base pointer and length are taken here; nothing is written.
        let (orig_ptr, orig_len) = unsafe {
            let orig = self.all_uninit_mut();
            (orig.as_mut_ptr(), orig.len())
        };
        self.debug_assert_validity();
        // SAFETY: `items_initialized <= orig_len`, so the suffix is in bounds. It excludes
        // the initialized prefix, so handing it out cannot de-initialize anything.
        unsafe {
            let ptr = orig_ptr.add(self.items_initialized);
            let len = orig_len.wrapping_sub(self.items_initialized);
            core::slice::from_raw_parts_mut(ptr, len)
        }
    }

    /// Mutable view of the initialized prefix of the buffer.
    #[inline]
    pub fn init_part_mut(&mut self) -> &mut [T::Item] {
        // Same invariant check as the other part accessors.
        self.debug_assert_validity();
        // SAFETY: only the base pointer is taken here; nothing is written.
        let orig_ptr = unsafe { self.all_uninit_mut().as_mut_ptr() };
        // SAFETY: the first `items_initialized` items are initialized and in bounds, and
        // `MaybeUninit<X>` has the same layout as `X`.
        unsafe {
            let len = self.items_initialized;
            core::slice::from_raw_parts_mut(orig_ptr as *mut T::Item, len)
        }
    }

    /// Returns the buffer asserted as initialized if it is complete; otherwise returns
    /// `self` unchanged in `Err`.
    #[inline]
    pub fn try_into_init(self) -> Result<AssertInit<T>, Self> {
        if self.is_completely_init() {
            // SAFETY: every item was just verified to be initialized.
            Ok(unsafe { self.assume_init() })
        } else {
            Err(self)
        }
    }

    /// Fills every remaining uninitialized item with `item` and returns the buffer
    /// asserted as initialized.
    pub fn finish_init_by_filling(mut self, item: T::Item) -> AssertInit<T>
    where
        T::Item: Copy,
    {
        self.fill_uninit_part(item);
        // SAFETY: `fill_uninit_part` initialized the entire remainder.
        unsafe { self.assume_init() }
    }

    /// Fills every remaining uninitialized item with `item`, leaving the buffer fully
    /// initialized.
    #[inline]
    pub fn fill_uninit_part(&mut self, item: T::Item)
    where
        T::Item: Copy,
    {
        crate::fill_uninit_slice(self.uninit_part_mut(), item);
        // SAFETY: the whole uninitialized suffix was just written.
        unsafe {
            self.advance_to_end();
        }
    }

    /// Fills the first `count` uninitialized items with `item` and marks them initialized.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds [`Self::remaining`].
    #[inline]
    pub fn partially_fill_uninit_part(&mut self, count: usize, item: T::Item)
    where
        T::Item: Copy,
    {
        crate::fill_uninit_slice(&mut self.uninit_part_mut()[..count], item);
        // SAFETY: exactly `count` items following the prefix were just written.
        unsafe { self.advance(count) }
    }

    /// Returns the initialized prefix and the uninitialized suffix together.
    #[inline]
    pub fn init_uninit_parts(&self) -> (&[T::Item], &[MaybeUninit<T::Item>]) {
        (self.init_part(), self.uninit_part())
    }

    /// Mutable views of the initialized prefix and the uninitialized suffix, borrowed
    /// simultaneously.
    #[inline]
    pub fn init_uninit_parts_mut(&mut self) -> (&mut [T::Item], &mut [MaybeUninit<T::Item>]) {
        // SAFETY: only the base pointer and length are taken here; nothing is written.
        let (all_ptr, all_len) = unsafe {
            let all = self.all_uninit_mut();
            (all.as_mut_ptr(), all.len())
        };
        self.debug_assert_validity();
        let split = self.items_initialized;
        // SAFETY: `split <= all_len`, so `[0, split)` and `[split, all_len)` are disjoint,
        // in-bounds ranges of the same buffer. The first range is initialized and may be
        // viewed as `T::Item`.
        unsafe {
            let init = core::slice::from_raw_parts_mut(all_ptr as *mut T::Item, split);
            let uninit =
                core::slice::from_raw_parts_mut(all_ptr.add(split), all_len.wrapping_sub(split));
            (init, uninit)
        }
    }
}
impl<T> BufferInitializer<T>
where
    T: Initialize<Item = u8>,
{
    /// Zeroes every remaining uninitialized byte and returns the buffer asserted as fully
    /// initialized.
    pub fn finish_init_by_zeroing(self) -> AssertInit<T> {
        self.finish_init_by_filling(0_u8)
    }

    /// Zeroes the first `count` bytes of the uninitialized part and marks them initialized.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of remaining uninitialized bytes.
    #[inline]
    pub fn partially_zero_uninit_part(&mut self, count: usize) {
        // Delegate rather than duplicating the fill-then-advance logic.
        self.partially_fill_uninit_part(count, 0_u8)
    }

    /// Zeroes the entire uninitialized part, leaving the buffer fully initialized.
    #[inline]
    pub fn zero_uninit_part(&mut self) {
        // `fill_uninit_part` already advances to the end, so no extra unsafe call is needed.
        self.fill_uninit_part(0_u8)
    }
}
/// Tracks how far a sequence of possibly-uninitialized vectors has been initialized.
///
/// Vectors before index `vectors_initialized` are fully initialized. In the vector at that
/// index (the "current" vector), the first `items_initialized_for_vector` items are
/// initialized. All later vectors are uninitialized.
#[derive(Debug)]
pub struct BuffersInitializer<T> {
    /// The wrapped collection of vectors.
    pub(crate) inner: T,
    /// Number of leading vectors that are fully initialized.
    pub(crate) vectors_initialized: usize,
    /// Number of initialized items at the start of the current vector.
    pub(crate) items_initialized_for_vector: usize,
}
impl<T> BuffersInitializer<T> {
    /// Wraps `inner` with no vector and no item marked as initialized.
    #[inline]
    pub const fn uninit(inner: T) -> Self {
        Self {
            vectors_initialized: 0,
            items_initialized_for_vector: 0,
            inner,
        }
    }

    /// Decomposes into `(inner, vectors_initialized, items_initialized_for_vector)`.
    #[inline]
    pub fn into_raw_parts(self) -> (T, usize, usize) {
        (
            self.inner,
            self.vectors_initialized,
            self.items_initialized_for_vector,
        )
    }

    /// Consumes the initializer, returning the wrapped vectors and discarding the
    /// initialization progress.
    #[inline]
    pub fn into_inner(self) -> T {
        self.into_raw_parts().0
    }
}
impl<T> BuffersInitializer<SingleVector<T>> {
    /// Converts a single-buffer initializer into a vectored one that holds just that buffer,
    /// carrying over its count of initialized items.
    ///
    /// The buffer becomes the current (first) vector, so `vectors_initialized` starts at 0.
    // NOTE(review): when `single` is already fully initialized, the result keeps
    // `vectors_initialized == 0` with the current vector's prefix covering its whole length.
    // `try_into_init` (which requires `vectors_remaining() == 0`) therefore presumably reports
    // it as incomplete until the current vector is advanced — confirm callers expect this.
    pub fn from_single_buffer_initializer(single: BufferInitializer<T>) -> Self {
        let items_initialized_for_vector = single.items_initialized;
        Self {
            inner: SingleVector(single.inner),
            vectors_initialized: 0,
            items_initialized_for_vector,
        }
    }
}
impl<T, Item> BuffersInitializer<T>
where
    T: InitializeVectored,
    T::UninitVector: Initialize<Item = Item>,
{
    /// Returns every vector, in order, as its possibly-uninitialized wrapper.
    #[inline]
    fn all_vectors_uninit(&self) -> &[T::UninitVector] {
        self.inner.as_maybe_uninit_vectors()
    }

    /// Returns the entire current vector (its initialized prefix and uninitialized
    /// remainder), or `None` once every vector has been initialized.
    #[inline]
    pub fn current_vector_all(&self) -> Option<&[MaybeUninit<Item>]> {
        self.debug_assert_validity();
        let vectors_initialized = self.vectors_initialized;
        if vectors_initialized != self.total_vector_count() {
            // SAFETY: every method that moves to the next vector keeps
            // `vectors_initialized <= total_vector_count()`, and equality was just ruled out,
            // so the index is in bounds.
            Some(unsafe {
                self.all_vectors_uninit()
                    .get_unchecked(vectors_initialized)
                    .as_maybe_uninit_slice()
            })
        } else {
            None
        }
    }

    /// Mutable counterpart of [`Self::current_vector_all`].
    ///
    /// # Safety
    ///
    /// The returned slice also covers the current vector's initialized prefix. The caller
    /// must not de-initialize any item in that prefix (e.g. by writing
    /// `MaybeUninit::uninit()`), because other methods read it as initialized.
    #[inline]
    pub unsafe fn current_vector_all_mut(&mut self) -> Option<&mut [MaybeUninit<Item>]> {
        self.debug_assert_validity();
        let vectors_initialized = self.vectors_initialized;
        if vectors_initialized != self.total_vector_count() {
            let all_vectors_uninit_mut = self.all_uninit_vectors_mut();
            // In bounds for the same reason as in `current_vector_all`.
            let current_vector_uninit_mut =
                all_vectors_uninit_mut.get_unchecked_mut(vectors_initialized);
            Some(current_vector_uninit_mut.as_maybe_uninit_slice_mut())
        } else {
            None
        }
    }

    /// Initialized prefix of the current vector, or `None` when no vectors remain.
    #[inline]
    pub fn current_vector_init_part(&self) -> Option<&[Item]> {
        let (init_part, _) = self.current_vector_init_uninit_parts()?;
        Some(init_part)
    }

    /// Uninitialized remainder of the current vector, or `None` when no vectors remain.
    #[inline]
    pub fn current_vector_uninit_part(&self) -> Option<&[MaybeUninit<Item>]> {
        let (_, uninit_part) = self.current_vector_init_uninit_parts()?;
        Some(uninit_part)
    }

    /// Initialized prefix and uninitialized remainder of the current vector, or `None` when
    /// no vectors remain.
    #[inline]
    pub fn current_vector_init_uninit_parts(&self) -> Option<(&[Item], &[MaybeUninit<Item>])> {
        let vector = self.current_vector_all()?;
        // SAFETY: `items_initialized_for_vector <= vector.len()` (checked by
        // `debug_assert_validity`), so both ranges are in bounds. The prefix is initialized,
        // and `MaybeUninit<Item>` has the same layout as `Item`.
        Some(unsafe {
            let init_vector_base_ptr = vector.as_ptr() as *const Item;
            let init_vector_len = self.items_initialized_for_vector;
            let init_vector = core::slice::from_raw_parts(init_vector_base_ptr, init_vector_len);
            let uninit_vector_base_ptr = vector.as_ptr().add(self.items_initialized_for_vector);
            let uninit_vector_len = vector.len().wrapping_sub(self.items_initialized_for_vector);
            let uninit_vector =
                core::slice::from_raw_parts(uninit_vector_base_ptr, uninit_vector_len);
            (init_vector, uninit_vector)
        })
    }

    /// Mutable initialized prefix of the current vector, or `None` when no vectors remain.
    #[inline]
    pub fn current_vector_init_part_mut(&mut self) -> Option<&mut [Item]> {
        let (init_part_mut, _) = self.current_vector_init_uninit_parts_mut()?;
        Some(init_part_mut)
    }

    /// Mutable uninitialized remainder of the current vector, or `None` when no vectors
    /// remain.
    #[inline]
    pub fn current_vector_uninit_part_mut(&mut self) -> Option<&mut [MaybeUninit<Item>]> {
        let (_, uninit_part_mut) = self.current_vector_init_uninit_parts_mut()?;
        Some(uninit_part_mut)
    }

    /// Mutable, simultaneously borrowed initialized prefix and uninitialized remainder of
    /// the current vector, or `None` when no vectors remain.
    #[inline]
    pub fn current_vector_init_uninit_parts_mut(
        &mut self,
    ) -> Option<(&mut [Item], &mut [MaybeUninit<Item>])> {
        // SAFETY: only the base pointer and length are taken; nothing is de-initialized.
        let (orig_base_ptr, orig_len) = unsafe {
            let vector = self.current_vector_all_mut()?;
            (vector.as_mut_ptr(), vector.len())
        };
        // SAFETY: `items_initialized_for_vector <= orig_len`, so the two ranges are disjoint
        // and in bounds. The prefix is initialized and may be viewed as `Item`.
        Some(unsafe {
            let init_vector_base_ptr = orig_base_ptr as *mut Item;
            let init_vector_len = self.items_initialized_for_vector;
            let init_vector =
                core::slice::from_raw_parts_mut(init_vector_base_ptr, init_vector_len);
            let uninit_vector_base_ptr = orig_base_ptr.add(self.items_initialized_for_vector);
            let uninit_vector_len = orig_len.wrapping_sub(self.items_initialized_for_vector);
            let uninit_vector =
                core::slice::from_raw_parts_mut(uninit_vector_base_ptr, uninit_vector_len);
            (init_vector, uninit_vector)
        })
    }

    /// Debug-only checks of the initializer's invariants.
    fn debug_assert_validity(&self) {
        debug_assert!(self
            .inner
            .as_maybe_uninit_vectors()
            .get(self.vectors_initialized)
            .map_or(true, |current_vector| current_vector
                .as_maybe_uninit_slice()
                .len()
                >= self.items_initialized_for_vector));
        debug_assert!(self.items_initialized_for_vector <= isize::MAX as usize);
        debug_assert!(self.inner.as_maybe_uninit_vectors().len() >= self.vectors_initialized);
    }

    /// Total number of vectors, initialized or not.
    #[inline]
    pub fn total_vector_count(&self) -> usize {
        self.inner.as_maybe_uninit_vectors().len()
    }

    /// Number of leading vectors that are fully initialized.
    #[inline]
    pub fn vectors_initialized(&self) -> usize {
        self.vectors_initialized
    }

    /// Number of vectors not yet fully initialized, including the current one.
    #[inline]
    pub fn vectors_remaining(&self) -> usize {
        self.total_vector_count()
            .wrapping_sub(self.vectors_initialized())
    }

    /// Counts the items that still have to be initialized: the uninitialized remainder of
    /// the current vector plus every item of the vectors after it.
    pub fn count_items_to_initialize(&self) -> usize {
        // The current vector contributes only its *uninitialized* part. The original code
        // mistakenly added the already-initialized prefix length here.
        let remaining_in_current = self
            .current_vector_uninit_part()
            .map_or(0, |uninit_part| uninit_part.len());
        let remaining_in_later_vectors = self
            .all_uninit_vectors()
            .iter()
            .skip(self.vectors_initialized.saturating_add(1))
            .map(|buffer| buffer.as_maybe_uninit_slice().len())
            .sum::<usize>();
        remaining_in_current + remaining_in_later_vectors
    }

    /// Total number of items across all vectors, initialized or not.
    pub fn count_total_items_in_all_vectors(&self) -> usize {
        self.all_uninit_vectors()
            .iter()
            .map(|buffer| buffer.as_maybe_uninit_slice().len())
            .sum()
    }

    /// Number of initialized items in the vector at `vector_index`, without bounds checking.
    ///
    /// # Safety
    ///
    /// `vector_index` must be less than [`Self::total_vector_count`].
    #[inline]
    pub unsafe fn items_initialized_for_vector_unchecked(&self, vector_index: usize) -> usize {
        let ordering = vector_index.cmp(&self.vectors_initialized);
        match ordering {
            core::cmp::Ordering::Equal => self.items_initialized_for_vector,
            core::cmp::Ordering::Greater => 0,
            core::cmp::Ordering::Less => self
                .all_uninit_vectors()
                .get_unchecked(vector_index)
                .as_maybe_uninit_slice()
                .len(),
        }
    }

    /// Number of initialized items in the vector at `vector_index`.
    ///
    /// # Panics
    ///
    /// Panics if `vector_index >= total_vector_count()`.
    #[inline]
    pub fn items_initialized_for_vector(&self, vector_index: usize) -> usize {
        assert!(vector_index < self.total_vector_count());
        // SAFETY: the index was just bounds-checked.
        unsafe { self.items_initialized_for_vector_unchecked(vector_index) }
    }

    /// Number of initialized items in the current vector, or 0 when no vectors remain.
    #[inline]
    pub fn items_initialized_for_current_vector(&self) -> usize {
        if self.vectors_initialized() != self.total_vector_count() {
            self.items_initialized_for_vector
        } else {
            0
        }
    }

    /// Returns every vector, in order, as its possibly-uninitialized wrapper.
    #[inline]
    pub fn all_uninit_vectors(&self) -> &[T::UninitVector] {
        self.inner.as_maybe_uninit_vectors()
    }

    /// Mutable access to every vector as its possibly-uninitialized wrapper.
    ///
    /// # Safety
    ///
    /// The caller must not de-initialize items that the initializer already counts as
    /// initialized.
    #[inline]
    pub unsafe fn all_uninit_vectors_mut(&mut self) -> &mut [T::UninitVector] {
        self.inner.as_maybe_uninit_vectors_mut()
    }

    /// Marks up to `count` items as initialized, crossing into subsequent vectors as needed.
    ///
    /// Returns the number of items actually advanced. This is less than `count` only when
    /// the vectors run out.
    ///
    /// # Safety
    ///
    /// The items being marked must actually have been initialized by the caller.
    #[inline]
    pub unsafe fn advance(&mut self, mut count: usize) -> usize {
        let mut items_advanced = 0_usize;
        while let Some(current_uninit_part) = self.current_vector_uninit_part() {
            let current_uninit_part_len = current_uninit_part.len();
            if count >= current_uninit_part_len {
                // The rest of this vector is covered: move on to the next one.
                self.vectors_initialized = self
                    .vectors_initialized
                    .checked_add(1)
                    .expect("reached usize::MAX when incrementing the buffer index");
                self.items_initialized_for_vector = 0;
                count -= current_uninit_part_len;
                items_advanced += current_uninit_part_len;
            } else {
                // Only part of this vector is covered: record exactly `count` and stop.
                self.items_initialized_for_vector += count;
                items_advanced += count;
                break;
            }
        }
        items_advanced
    }

    /// Marks the rest of the current vector as initialized and moves on to the next vector.
    /// Does nothing when every vector is already initialized.
    ///
    /// # Safety
    ///
    /// The whole current vector must actually be initialized.
    pub unsafe fn advance_current_vector_to_end(&mut self) {
        self.debug_assert_validity();
        // Stepping past the last vector would break the `vectors_initialized <= total`
        // invariant that `current_vector_all` relies on for its unchecked indexing.
        if self.vectors_initialized < self.total_vector_count() {
            self.vectors_initialized += 1;
            self.items_initialized_for_vector = 0;
        }
    }

    /// Marks the next `count` items of the current vector as initialized. If this completes
    /// the vector, moves on to the next one.
    ///
    /// # Safety
    ///
    /// Those `count` items must actually have been initialized by the caller.
    ///
    /// # Panics
    ///
    /// Panics if `count` would advance past the end of the current vector, or if
    /// `count > 0` when no vectors remain.
    pub unsafe fn advance_current_vector(&mut self, count: usize) {
        self.debug_assert_validity();
        if let Some(current_vector) = self.current_vector_all() {
            let current_vector_len = current_vector.len();
            // `checked_add` keeps a huge `count` from wrapping around in release builds and
            // slipping past the bounds assertion below.
            let end = self
                .items_initialized_for_vector
                .checked_add(count)
                .expect("cannot advance beyond the end of the current vector");
            assert!(end <= current_vector_len);
            if end == current_vector_len {
                self.vectors_initialized += 1;
                self.items_initialized_for_vector = 0;
            } else {
                self.items_initialized_for_vector = end;
            }
        } else if count > 0 {
            panic!("cannot advance beyond the end of the current vector")
        }
    }

    /// Fills the first `count` uninitialized items of the current vector with `item` and
    /// marks them initialized.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the current vector's uninitialized length, or if
    /// `count > 0` when no vectors remain.
    pub fn partially_fill_current_vector_uninit_part(&mut self, count: usize, item: Item)
    where
        Item: Copy,
    {
        if let Some(current_vector_uninit_part_mut) = self.current_vector_uninit_part_mut() {
            crate::fill_uninit_slice(&mut current_vector_uninit_part_mut[..count], item);
            // SAFETY: exactly `count` items after the prefix were just written.
            unsafe { self.advance_current_vector(count) }
        } else if count > 0 {
            panic!("cannot partially fill a vector when none are left");
        }
    }

    /// Fills the current vector's entire uninitialized remainder with `item` and moves on to
    /// the next vector. Does nothing when no vectors remain.
    pub fn fill_current_vector_uninit_part(&mut self, item: Item)
    where
        Item: Copy,
    {
        if let Some(current_vector_uninit_part_mut) = self.current_vector_uninit_part_mut() {
            crate::fill_uninit_slice(current_vector_uninit_part_mut, item);
            // SAFETY: the whole remainder of the current vector was just written.
            unsafe { self.advance_current_vector_to_end() }
        }
    }

    /// Returns the vectors asserted as initialized when every vector is complete; otherwise
    /// returns `self` unchanged in `Err`.
    pub fn try_into_init(self) -> Result<AssertInitVectors<T>, Self> {
        if self.vectors_remaining() == 0 {
            // SAFETY: no vectors remain, so every vector has been initialized.
            Ok(unsafe { AssertInitVectors::new_unchecked(self.into_inner()) })
        } else {
            Err(self)
        }
    }
}
impl<T> BuffersInitializer<T>
where
    T: InitializeVectored,
    T::UninitVector: Initialize<Item = u8>,
{
    /// Zeroes the first `count` bytes of the current vector's uninitialized part and marks
    /// them initialized; a byte-specialized shorthand for
    /// `partially_fill_current_vector_uninit_part(count, 0)`.
    pub fn partially_zero_current_vector_uninit_part(&mut self, count: usize) {
        let zero = 0_u8;
        self.partially_fill_current_vector_uninit_part(count, zero)
    }

    /// Zeroes the whole uninitialized remainder of the current vector and moves on to the
    /// next one; shorthand for `fill_current_vector_uninit_part(0)`.
    pub fn zero_current_vector_uninit_part(&mut self) {
        let zero = 0_u8;
        self.fill_current_vector_uninit_part(zero)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tests for the single-buffer `BufferInitializer`.
    mod single {
        use super::*;

        #[test]
        fn new_fills_completely() {
            let reference = *b"Calling BufferInitializer::new() will ensure that the initialization marker is put at the end of the slice, making it fully zero-cost when already using initialized memory.";
            let mut backing = reference;
            let mut initializer = BufferInitializer::new(&mut backing[..]);

            // Nothing is left to initialize...
            assert_eq!(initializer.remaining(), 0);
            assert_eq!(initializer.capacity(), reference.len());
            assert!(initializer.is_completely_init());
            assert!(!initializer.is_completely_uninit());

            // ...so the uninitialized views are empty...
            assert!(initializer.uninit_part().is_empty());
            assert!(initializer.uninit_part_mut().is_empty());

            // ...and the initialized views expose the original contents.
            assert_eq!(initializer.init_part(), reference);
            assert_eq!(initializer.init_part_mut(), reference);

            assert!(initializer.try_into_init().is_ok());
        }

        #[test]
        fn basic_initialization() {
            let mut storage = [MaybeUninit::uninit(); 32];
            let initializer = BufferInitializer::uninit(&mut storage[..]);
            let filled = initializer.finish_init_by_filling(42_u8);
            assert!(!filled.iter().any(|&byte| byte != 42_u8));
        }

        #[test]
        fn buffer_parts() {
            let mut storage = [MaybeUninit::<u8>::uninit(); 32];
            let mut initializer = BufferInitializer::uninit(&mut storage[..]);

            assert_eq!(initializer.uninit_part().len(), 32);
            assert_eq!(initializer.uninit_part_mut().len(), 32);
            assert_eq!(initializer.init_part(), &[]);
            assert_eq!(initializer.init_part_mut(), &mut []);
            assert!(!initializer.is_completely_init());
        }
    }

    /// Tests for the vectored `BuffersInitializer`.
    mod vectored {
        use super::*;

        #[test]
        fn fill_uninit_part() {
            let mut head = [MaybeUninit::uninit(); 32];
            let mut middle = [MaybeUninit::uninit(); 128];
            let mut tail = [MaybeUninit::uninit(); 64];
            let mut vectors = [&mut head[..], &mut middle[..], &mut tail[..]];
            let mut initializer = BuffersInitializer::uninit(&mut vectors[..]);

            // Zeroing the whole first vector moves on to the second.
            initializer.zero_current_vector_uninit_part();
            assert_eq!(initializer.vectors_initialized(), 1);
            assert_eq!(initializer.items_initialized_for_current_vector(), 0);

            // A partial fill stays within the second vector...
            initializer.partially_zero_current_vector_uninit_part(96);
            assert_eq!(initializer.vectors_initialized(), 1);
            assert_eq!(initializer.items_initialized_for_current_vector(), 96);

            // ...until the remaining 32 bytes complete it.
            initializer.partially_fill_current_vector_uninit_part(32, 0x13_u8);
            assert_eq!(initializer.vectors_initialized(), 2);
            assert_eq!(initializer.items_initialized_for_current_vector(), 0);

            initializer.partially_fill_current_vector_uninit_part(16, 0x37_u8);
            assert_eq!(initializer.vectors_initialized(), 2);
            assert_eq!(initializer.items_initialized_for_current_vector(), 16);

            // Filling the rest of the last vector leaves nothing current.
            initializer.fill_current_vector_uninit_part(0x42);
            assert_eq!(initializer.vectors_initialized(), 3);
            assert!(initializer.current_vector_all().is_none());
        }
    }
}