use core::mem::MaybeUninit;
use std::any::type_name;
use std::fmt::{Debug, Formatter};
use std::io::Write;
use std::ops::{Deref, DerefMut};

use bytes::buf::UninitSlice;
use bytes::{Buf, BufMut, BytesMut};
use vortex_error::{VortexExpect, vortex_panic};

use crate::debug::TruncatedDebug;
use crate::trusted_len::TrustedLen;
use crate::{Alignment, Buffer, ByteBufferMut};

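/// A mutable buffer of items `T`, analogous to `Vec<T>`, but backed by [`BytesMut`] and
/// carrying an explicit [`Alignment`] for the underlying allocation.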
#[derive(PartialEq, Eq)]
pub struct BufferMut<T> {
    pub(crate) bytes: BytesMut,
    pub(crate) length: usize,
    pub(crate) alignment: Alignment,
    pub(crate) _marker: std::marker::PhantomData<T>,
}

impl<T> BufferMut<T> {
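    /// Creates a new `BufferMut` with capacity for `capacity` items of `T`, aligned to the
    /// natural alignment of `T`.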
    pub fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity_aligned(capacity, Alignment::of::<T>())
    }

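    /// Creates a new `BufferMut` with capacity for `capacity` items of `T` and the given
    /// alignment.
    ///
    /// # Panics
    ///
    /// Panics if `alignment` is not aligned to the alignment of `T`.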
    pub fn with_capacity_aligned(capacity: usize, alignment: Alignment) -> Self {
        if !alignment.is_aligned_to(Alignment::of::<T>()) {
            vortex_panic!(
                "Alignment {} must align to the scalar type's alignment {}",
                alignment,
                align_of::<T>()
            );
        }

        // Over-allocate by `alignment` bytes so the start of the buffer can always be advanced
        // to an aligned address.
        let mut bytes = BytesMut::with_capacity((capacity * size_of::<T>()) + *alignment);
        bytes.align_empty(alignment);

        Self {
            bytes,
            length: 0,
            alignment,
            _marker: Default::default(),
        }
    }

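    /// Creates a new `BufferMut` of `len` zeroed items, aligned to the natural alignment of
    /// `T`.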
    pub fn zeroed(len: usize) -> Self {
        Self::zeroed_aligned(len, Alignment::of::<T>())
    }

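    /// Creates a new `BufferMut` of `len` zeroed items with the given alignment.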
    pub fn zeroed_aligned(len: usize, alignment: Alignment) -> Self {
        let mut bytes = BytesMut::zeroed((len * size_of::<T>()) + *alignment);
        bytes.advance(bytes.as_ptr().align_offset(*alignment));
        unsafe { bytes.set_len(len * size_of::<T>()) };
        Self {
            bytes,
            length: len,
            alignment,
            _marker: Default::default(),
        }
    }

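    /// Creates an empty `BufferMut` aligned to the natural alignment of `T`.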
    pub fn empty() -> Self {
        Self::empty_aligned(Alignment::of::<T>())
    }

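    /// Creates an empty `BufferMut` with the given alignment.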
    pub fn empty_aligned(alignment: Alignment) -> Self {
        BufferMut::with_capacity_aligned(0, alignment)
    }

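    /// Creates a `BufferMut` containing `len` copies of `item`.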
    pub fn full(item: T, len: usize) -> Self
    where
        T: Copy,
    {
        let mut buffer = BufferMut::<T>::with_capacity(len);
        buffer.push_n(item, len);
        buffer
    }

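    /// Creates a `BufferMut` by copying the items of `other`, aligned to the natural alignment
    /// of `T`.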
    pub fn copy_from(other: impl AsRef<[T]>) -> Self {
        Self::copy_from_aligned(other, Alignment::of::<T>())
    }

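    /// Creates a `BufferMut` by copying the items of `other` into a new buffer with the given
    /// alignment.
    ///
    /// # Panics
    ///
    /// Panics if `alignment` is not aligned to the alignment of `T`.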
    pub fn copy_from_aligned(other: impl AsRef<[T]>, alignment: Alignment) -> Self {
        if !alignment.is_aligned_to(Alignment::of::<T>()) {
            vortex_panic!("Given alignment is not aligned to type T")
        }
        let other = other.as_ref();
        let mut buffer = Self::with_capacity_aligned(other.len(), alignment);
        buffer.extend_from_slice(other);
        debug_assert_eq!(buffer.alignment(), alignment);
        buffer
    }

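    /// Returns the alignment of the buffer.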
    #[inline(always)]
    pub fn alignment(&self) -> Alignment {
        self.alignment
    }

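    /// Returns the number of items in the buffer.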
    #[inline(always)]
    pub fn len(&self) -> usize {
        debug_assert_eq!(self.length, self.bytes.len() / size_of::<T>());
        self.length
    }

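    /// Returns true if the buffer contains no items.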
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }

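    /// Returns the buffer's capacity in items of `T`.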
    #[inline]
    pub fn capacity(&self) -> usize {
        self.bytes.capacity() / size_of::<T>()
    }

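    /// Returns the contents of the buffer as a slice of `T`.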
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        let raw_slice = self.bytes.as_ref();
        // SAFETY: the buffer is aligned to T and holds exactly `length` initialized items.
        unsafe { std::slice::from_raw_parts(raw_slice.as_ptr().cast(), self.length) }
    }

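    /// Returns the contents of the buffer as a mutable slice of `T`.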
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let raw_slice = self.bytes.as_mut();
        // SAFETY: the buffer is aligned to T and holds exactly `length` initialized items.
        unsafe { std::slice::from_raw_parts_mut(raw_slice.as_mut_ptr().cast(), self.length) }
    }

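    /// Clears the buffer, removing all items while retaining the allocation.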
    #[inline]
    pub fn clear(&mut self) {
        unsafe { self.bytes.set_len(0) }
        self.length = 0;
    }

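    /// Shortens the buffer to `len` items, dropping everything beyond that length.
    ///
    /// Has no effect if `len` is greater than or equal to the current length.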
    #[inline]
    pub fn truncate(&mut self, len: usize) {
        if len <= self.len() {
            unsafe { self.set_len(len) };
        }
    }

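    /// Reserves capacity for at least `additional` more items of `T`.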
    #[inline]
    pub fn reserve(&mut self, additional: usize) {
        let additional_bytes = additional * size_of::<T>();
        if additional_bytes <= self.bytes.capacity() - self.bytes.len() {
            return;
        }

        self.reserve_allocate(additional);
    }

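    /// Slow path for [`BufferMut::reserve`]: allocates a larger buffer, re-aligns it, and
    /// copies the existing contents across.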
    fn reserve_allocate(&mut self, additional: usize) {
        // Grow to at least the required capacity (plus alignment padding), doubling the
        // current allocation where possible to amortize repeated reserves.
        let new_capacity: usize = ((self.length + additional) * size_of::<T>()) + *self.alignment;
        let new_capacity = new_capacity.max(self.bytes.capacity() * 2);

        let mut bytes = BytesMut::with_capacity(new_capacity);
        bytes.align_empty(self.alignment);
        bytes.extend_from_slice(&self.bytes);
        self.bytes = bytes;
    }

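    /// Returns the remaining spare capacity of the buffer as a slice of `MaybeUninit<T>`.
    ///
    /// The returned slice can be used to fill the buffer with data before marking the data as
    /// initialized using [`BufferMut::set_len`].
    ///
    /// A minimal sketch of the intended flow (marked `ignore` since the `vortex_buffer`
    /// doc-test path is an assumption):
    ///
    /// ```ignore
    /// use vortex_buffer::BufferMut;
    ///
    /// let mut buf = BufferMut::<i32>::with_capacity(8);
    ///
    /// // Write into the uninitialized tail...
    /// let spare = buf.spare_capacity_mut();
    /// spare[0].write(1);
    /// spare[1].write(2);
    ///
    /// // ...then commit the two initialized items.
    /// unsafe { buf.set_len(2) };
    /// assert_eq!(buf.as_slice(), &[1, 2]);
    /// ```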
    #[inline]
    pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] {
        let dst = self.bytes.spare_capacity_mut().as_mut_ptr();
        unsafe {
            std::slice::from_raw_parts_mut(
                dst as *mut MaybeUninit<T>,
                self.capacity() - self.length,
            )
        }
    }

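    /// Sets the length of the buffer to `len` items.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `len` is no greater than the buffer's capacity and that the
    /// first `len` items are initialized.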
    #[inline]
    pub unsafe fn set_len(&mut self, len: usize) {
        debug_assert!(len <= self.capacity());
        unsafe { self.bytes.set_len(len * size_of::<T>()) };
        self.length = len;
    }

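    /// Appends an item to the end of the buffer, reserving capacity as needed.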
    #[inline]
    pub fn push(&mut self, value: T) {
        self.reserve(1);
        unsafe { self.push_unchecked(value) }
    }

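    /// Appends an item to the end of the buffer without checking capacity.
    ///
    /// # Safety
    ///
    /// The caller must ensure there is spare capacity for at least one more item.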
    #[inline]
    pub unsafe fn push_unchecked(&mut self, item: T) {
        unsafe {
            // Write the item into the first uninitialized slot, then bump the byte length.
            let dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
            dst.write(item);
            self.bytes.set_len(self.bytes.len() + size_of::<T>())
        }
        self.length += 1;
    }

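    /// Appends `n` copies of `item` to the end of the buffer, reserving capacity as needed.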
    #[inline]
    pub fn push_n(&mut self, item: T, n: usize)
    where
        T: Copy,
    {
        self.reserve(n);
        unsafe { self.push_n_unchecked(item, n) }
    }

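    /// Appends `n` copies of `item` to the end of the buffer without checking capacity.
    ///
    /// # Safety
    ///
    /// The caller must ensure there is spare capacity for at least `n` more items.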
    #[inline]
    pub unsafe fn push_n_unchecked(&mut self, item: T, n: usize)
    where
        T: Copy,
    {
        let mut dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        unsafe {
            // Write `n` copies into the spare capacity, then bump the byte length once.
            let end = dst.add(n);
            while dst < end {
                dst.write(item);
                dst = dst.add(1);
            }
            self.bytes.set_len(self.bytes.len() + (n * size_of::<T>()));
        }
        self.length += n;
    }

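    /// Appends all items from `slice` to the end of the buffer, reserving capacity as needed.
    ///
    /// A short usage sketch (marked `ignore` since the `vortex_buffer` doc-test path is an
    /// assumption):
    ///
    /// ```ignore
    /// use vortex_buffer::BufferMut;
    ///
    /// let mut buf = BufferMut::<i32>::empty();
    /// buf.extend_from_slice(&[1, 2, 3]);
    /// buf.extend_from_slice(&[4, 5]);
    /// assert_eq!(buf.as_slice(), &[1, 2, 3, 4, 5]);
    /// ```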
    #[inline]
    pub fn extend_from_slice(&mut self, slice: &[T]) {
        self.reserve(slice.len());
        // Reinterpret the item slice as raw bytes; `size_of_val` gives its byte length.
        let raw_slice =
            unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
        self.bytes.extend_from_slice(raw_slice);
        self.length += slice.len();
    }

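    /// Splits the buffer into two at the given item index, returning the tail.
    ///
    /// Afterwards `self` contains items `[0, at)` and the returned buffer contains items
    /// `[at, len)`.
    ///
    /// # Panics
    ///
    /// Panics if `at` is greater than the buffer's capacity, or if the split point (in bytes)
    /// is not a multiple of the buffer's alignment.
    ///
    /// An illustrative sketch (marked `ignore` since the `vortex_buffer` doc-test path is an
    /// assumption):
    ///
    /// ```ignore
    /// use vortex_buffer::buffer_mut;
    ///
    /// let mut buf = buffer_mut![0u8, 1, 2, 3];
    /// let tail = buf.split_off(2);
    /// assert_eq!(buf.as_slice(), &[0, 1]);
    /// assert_eq!(tail.as_slice(), &[2, 3]);
    /// ```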
    pub fn split_off(&mut self, at: usize) -> Self {
        if at > self.capacity() {
            vortex_panic!("Cannot split buffer of capacity {} at {}", self.capacity(), at);
        }

        let bytes_at = at * size_of::<T>();
        if !bytes_at.is_multiple_of(*self.alignment) {
            vortex_panic!(
                "Cannot split buffer at {}, resulting alignment is not {}",
                at,
                self.alignment
            );
        }

        let new_bytes = self.bytes.split_off(bytes_at);

        // `at` may fall in the spare capacity beyond `length`, in which case the new buffer is
        // empty and our own length is unchanged.
        let new_length = self.length.saturating_sub(at);
        self.length = self.length.min(at);

        BufferMut {
            bytes: new_bytes,
            length: new_length,
            alignment: self.alignment,
            _marker: Default::default(),
        }
    }

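    /// Absorbs `other`, typically a buffer previously split off from this one. If the two
    /// buffers were contiguous this is cheap; otherwise the contents of `other` are copied.
    ///
    /// # Panics
    ///
    /// Panics if the two buffers have different alignments.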
    pub fn unsplit(&mut self, other: Self) {
        if self.alignment != other.alignment {
            vortex_panic!(
                "Cannot unsplit buffers with different alignments: {} and {}",
                self.alignment,
                other.alignment
            );
        }
        self.bytes.unsplit(other.bytes);
        self.length += other.length;
    }

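    /// Freezes this buffer, converting it into an immutable [`Buffer`].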
    pub fn freeze(self) -> Buffer<T> {
        Buffer {
            bytes: self.bytes.freeze(),
            length: self.length,
            alignment: self.alignment,
            _marker: Default::default(),
        }
    }

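    /// Maps each item of the buffer through `f`, reusing the existing allocation.
    ///
    /// # Panics
    ///
    /// Panics if `R` does not have the same size as `T`.
    ///
    /// An illustrative sketch (marked `ignore` since the `vortex_buffer` doc-test path is an
    /// assumption):
    ///
    /// ```ignore
    /// use vortex_buffer::buffer_mut;
    ///
    /// let buf = buffer_mut![0i32, 1, 2];
    /// let buf = buf.map_each_in_place(|i| (i + 1) as u32);
    /// assert_eq!(buf.as_slice(), &[1u32, 2, 3]);
    /// ```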
    pub fn map_each_in_place<R, F>(self, mut f: F) -> BufferMut<R>
    where
        T: Copy,
        F: FnMut(T) -> R,
    {
        assert_eq!(
            size_of::<T>(),
            size_of::<R>(),
            "Size of T and R do not match"
        );
        // SAFETY: T and R have the same size, and the buffer's layout does not depend on the
        // item type.
        let mut buf: BufferMut<R> = unsafe { std::mem::transmute(self) };
        buf.iter_mut()
            .for_each(|item| *item = f(unsafe { std::mem::transmute_copy(item) }));
        buf
    }

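    /// Returns a buffer with data aligned to `alignment`, copying into a freshly allocated
    /// aligned buffer only if the current data pointer is not already sufficiently aligned.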
    pub fn aligned(self, alignment: Alignment) -> Self {
        if self.as_ptr().align_offset(*alignment) == 0 {
            self
        } else {
            Self::copy_from_aligned(self, alignment)
        }
    }
}

impl<T> Clone for BufferMut<T> {
    fn clone(&self) -> Self {
        let mut buffer = BufferMut::<T>::with_capacity_aligned(self.capacity(), self.alignment);
        buffer.extend_from_slice(self.as_slice());
        buffer
    }
}

impl<T: Debug> Debug for BufferMut<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct(&format!("BufferMut<{}>", type_name::<T>()))
            .field("length", &self.length)
            .field("alignment", &self.alignment)
            .field("as_slice", &TruncatedDebug(self.as_slice()))
            .finish()
    }
}

impl<T> Default for BufferMut<T> {
    fn default() -> Self {
        Self::empty()
    }
}

impl<T> Deref for BufferMut<T> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<T> DerefMut for BufferMut<T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

impl<T> AsRef<[T]> for BufferMut<T> {
    #[inline]
    fn as_ref(&self) -> &[T] {
        self.as_slice()
    }
}

impl<T> AsMut<[T]> for BufferMut<T> {
    #[inline]
    fn as_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}

impl<T> BufferMut<T> {
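    /// Extends the buffer from an iterator, writing directly into spare capacity where
    /// possible and falling back to `push` for any items beyond the reserved capacity.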
    fn extend_iter(&mut self, mut iter: impl Iterator<Item = T>) {
        // Reserve based on the iterator's lower bound; it may still yield more items.
        let (lower_bound, _) = iter.size_hint();

        self.reserve(lower_bound);

        let unwritten = self.capacity() - self.len();

        let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        let mut dst: *mut T = begin.cast_mut();

        // Fast path: write items directly into the spare capacity.
        for _ in 0..unwritten {
            let Some(item) = iter.next() else {
                break;
            };

            // SAFETY: `dst` stays within the spare capacity checked above.
            unsafe { dst.write(item) };
            unsafe { dst = dst.add(1) };
        }

        // Commit however many items the fast path wrote.
        let items_written = unsafe { dst.offset_from_unsigned(begin) };
        let length = self.len() + items_written;

        unsafe { self.set_len(length) };

        // Slow path: the iterator outlived our spare capacity, so push the rest one by one.
        iter.for_each(|item| self.push(item));
    }

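    /// Extends the buffer from a [`TrustedLen`] iterator, writing all items directly into the
    /// reserved spare capacity.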
    pub fn extend_trusted<I: TrustedLen<Item = T>>(&mut self, iter: I) {
        // TrustedLen guarantees the upper bound is `Some` and exact.
        let (_, upper_bound) = iter.size_hint();
        self.reserve(
            upper_bound
                .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
        );

        let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        let mut dst: *mut T = begin.cast_mut();

        iter.for_each(|item| {
            // SAFETY: TrustedLen guarantees we never write past the reserved capacity.
            unsafe { dst.write(item) };
            unsafe { dst = dst.add(1) };
        });

        let items_written = unsafe { dst.offset_from_unsigned(begin) };
        let length = self.len() + items_written;

        unsafe { self.set_len(length) };
    }

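    /// Constructs a new `BufferMut` from a [`TrustedLen`] iterator.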
    pub fn from_trusted_len_iter<I>(iter: I) -> Self
    where
        I: TrustedLen<Item = T>,
    {
        let (_, upper_bound) = iter.size_hint();
        let mut buffer = Self::with_capacity(
            upper_bound
                .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
        );

        buffer.extend_trusted(iter);
        buffer
    }
}

impl<T> Extend<T> for BufferMut<T> {
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        self.extend_iter(iter.into_iter())
    }
}

impl<'a, T> Extend<&'a T> for BufferMut<T>
where
    T: Copy + 'a,
{
    #[inline]
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        self.extend_iter(iter.into_iter().copied())
    }
}

impl<T> FromIterator<T> for BufferMut<T> {
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        // Start empty; `extend` reserves based on the iterator's size hint.
        let mut buffer = Self::with_capacity(0);
        buffer.extend(iter);
        buffer
    }
}

impl Buf for ByteBufferMut {
    fn remaining(&self) -> usize {
        self.len()
    }

    fn chunk(&self) -> &[u8] {
        self.as_slice()
    }

    fn advance(&mut self, cnt: usize) {
        if !cnt.is_multiple_of(*self.alignment) {
            vortex_panic!(
                "Cannot advance buffer by {} items, resulting alignment is not {}",
                cnt,
                self.alignment
            );
        }
        self.bytes.advance(cnt);
        self.length -= cnt;
    }
}

unsafe impl BufMut for ByteBufferMut {
    #[inline]
    fn remaining_mut(&self) -> usize {
        usize::MAX - self.len()
    }

    #[inline]
    unsafe fn advance_mut(&mut self, cnt: usize) {
        if !cnt.is_multiple_of(*self.alignment) {
            vortex_panic!(
                "Cannot advance buffer by {} items, resulting alignment is not {}",
                cnt,
                self.alignment
            );
        }
        unsafe { self.bytes.advance_mut(cnt) };
        // Advancing the write cursor marks `cnt` more bytes as initialized.
        self.length += cnt;
    }

    #[inline]
    fn chunk_mut(&mut self) -> &mut UninitSlice {
        self.bytes.chunk_mut()
    }

    fn put<T: Buf>(&mut self, mut src: T)
    where
        Self: Sized,
    {
        while src.has_remaining() {
            let chunk = src.chunk();
            self.extend_from_slice(chunk);
            src.advance(chunk.len());
        }
    }

    #[inline]
    fn put_slice(&mut self, src: &[u8]) {
        self.extend_from_slice(src);
    }

    #[inline]
    fn put_bytes(&mut self, val: u8, cnt: usize) {
        self.push_n(val, cnt)
    }
}

trait AlignedBytesMut {
    /// Aligns an empty `BytesMut` by advancing its start pointer to the next aligned address,
    /// consuming part of the spare capacity as padding.
    fn align_empty(&mut self, alignment: Alignment);
}

impl AlignedBytesMut for BytesMut {
    fn align_empty(&mut self, alignment: Alignment) {
        if !self.is_empty() {
            vortex_panic!("ByteBufferMut must be empty");
        }

        // Ensure there is enough spare capacity to absorb the padding.
        let padding = self.as_ptr().align_offset(*alignment);
        self.capacity()
            .checked_sub(padding)
            .vortex_expect("Not enough capacity to align buffer");

        // Mark the padding bytes as initialized, then advance past them so the buffer starts
        // at an aligned address with zero length.
        unsafe { self.set_len(padding) };
        self.advance(padding);
    }
}

impl Write for ByteBufferMut {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.extend_from_slice(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use bytes::{Buf, BufMut};

    use crate::{Alignment, BufferMut, ByteBufferMut, buffer_mut};

    #[test]
    fn capacity() {
        let mut n = 57;
        let mut buf = BufferMut::<i32>::with_capacity_aligned(n, Alignment::new(1024));
        assert!(buf.capacity() >= 57);

        while n > 0 {
            buf.push(0);
            assert!(buf.capacity() >= n);
            n -= 1
        }

        assert_eq!(buf.alignment(), Alignment::new(1024));
    }

    #[test]
    fn from_iter() {
        let buf = BufferMut::from_iter([0, 10, 20, 30]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30]);
    }

    #[test]
    fn extend() {
        let mut buf = BufferMut::empty();
        buf.extend([0i32, 10, 20, 30]);
        buf.extend([40, 50, 60]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30, 40, 50, 60]);
    }

    #[test]
    fn push() {
        let mut buf = BufferMut::empty();
        buf.push(1);
        buf.push(2);
        buf.push(3);
        assert_eq!(buf.as_slice(), &[1, 2, 3]);
    }

    #[test]
    fn push_n() {
        let mut buf = BufferMut::empty();
        buf.push_n(0, 100);
        assert_eq!(buf.as_slice(), &[0; 100]);
    }

    #[test]
    fn as_mut() {
        let mut buf = buffer_mut![0, 1, 2];
        // Mutate via DerefMut indexing.
        buf[1] = 0;
        // Mutate via AsMut.
        buf.as_mut()[2] = 0;
        assert_eq!(buf.as_slice(), &[0, 0, 0]);
    }

    #[test]
    fn map_each() {
        let buf = buffer_mut![0i32, 1, 2];
        let buf = buf.map_each_in_place(|i| (i + 1) as u32);
        assert_eq!(buf.as_slice(), &[1u32, 2, 3]);
    }

    #[test]
    fn bytes_buf() {
        let mut buf = ByteBufferMut::copy_from("helloworld".as_bytes());
        assert_eq!(buf.remaining(), 10);
        assert_eq!(buf.chunk(), b"helloworld");

        Buf::advance(&mut buf, 5);
        assert_eq!(buf.remaining(), 5);
        assert_eq!(buf.as_slice(), b"world");
        assert_eq!(buf.chunk(), b"world");
    }

    #[test]
    fn bytes_buf_mut() {
        let mut buf = ByteBufferMut::copy_from("hello".as_bytes());
        assert_eq!(BufMut::remaining_mut(&buf), usize::MAX - 5);

        BufMut::put_slice(&mut buf, b"world");
        assert_eq!(buf.as_slice(), b"helloworld");
    }
}