1use core::mem::MaybeUninit;
5use std::any::type_name;
6use std::fmt::Debug;
7use std::fmt::Formatter;
8use std::io::Write;
9use std::ops::Deref;
10use std::ops::DerefMut;
11
12use bytes::Buf;
13use bytes::BufMut;
14use bytes::BytesMut;
15use bytes::buf::UninitSlice;
16use vortex_error::VortexExpect;
17use vortex_error::vortex_panic;
18
19use crate::Alignment;
20use crate::Buffer;
21use crate::ByteBufferMut;
22use crate::debug::TruncatedDebug;
23use crate::trusted_len::TrustedLen;
24
/// A mutable buffer of `T` values backed by a [`BytesMut`] and tagged with an [`Alignment`].
#[derive(PartialEq, Eq)]
pub struct BufferMut<T> {
    /// Backing byte storage holding the elements.
    pub(crate) bytes: BytesMut,
    /// Number of `T` elements currently stored (an element count, not a byte count).
    pub(crate) length: usize,
    /// Alignment recorded for this buffer.
    pub(crate) alignment: Alignment,
    /// Zero-sized marker tying the buffer to its element type.
    pub(crate) _marker: std::marker::PhantomData<T>,
}
33
34impl<T> BufferMut<T> {
35 pub fn with_capacity(capacity: usize) -> Self {
37 Self::with_capacity_aligned(capacity, Alignment::of::<T>())
38 }
39
40 pub fn with_capacity_aligned(capacity: usize, alignment: Alignment) -> Self {
42 if !alignment.is_aligned_to(Alignment::of::<T>()) {
43 vortex_panic!(
44 "Alignment {} must align to the scalar type's alignment {}",
45 alignment,
46 align_of::<T>()
47 );
48 }
49
50 let mut bytes = BytesMut::with_capacity((capacity * size_of::<T>()) + *alignment);
51 bytes.align_empty(alignment);
52
53 Self {
54 bytes,
55 length: 0,
56 alignment,
57 _marker: Default::default(),
58 }
59 }
60
61 pub fn zeroed(len: usize) -> Self {
63 Self::zeroed_aligned(len, Alignment::of::<T>())
64 }
65
66 pub fn zeroed_aligned(len: usize, alignment: Alignment) -> Self {
68 let mut bytes = BytesMut::zeroed((len * size_of::<T>()) + *alignment);
69 bytes.advance(bytes.as_ptr().align_offset(*alignment));
70 unsafe { bytes.set_len(len * size_of::<T>()) };
71 let actual_len = bytes.len().checked_div(size_of::<T>()).unwrap_or(0);
72 Self {
73 bytes,
74 length: actual_len,
75 alignment,
76 _marker: Default::default(),
77 }
78 }
79
80 pub fn empty() -> Self {
82 Self::empty_aligned(Alignment::of::<T>())
83 }
84
85 pub fn empty_aligned(alignment: Alignment) -> Self {
87 BufferMut::with_capacity_aligned(0, alignment)
88 }
89
90 pub fn full(item: T, len: usize) -> Self
92 where
93 T: Copy,
94 {
95 let mut buffer = BufferMut::<T>::with_capacity(len);
96 buffer.push_n(item, len);
97 buffer
98 }
99
100 pub fn copy_from(other: impl AsRef<[T]>) -> Self {
102 Self::copy_from_aligned(other, Alignment::of::<T>())
103 }
104
105 pub fn copy_from_aligned(other: impl AsRef<[T]>, alignment: Alignment) -> Self {
111 if !alignment.is_aligned_to(Alignment::of::<T>()) {
112 vortex_panic!("Given alignment is not aligned to type T")
113 }
114 let other = other.as_ref();
115 let mut buffer = Self::with_capacity_aligned(other.len(), alignment);
116 buffer.extend_from_slice(other);
117 debug_assert_eq!(buffer.alignment(), alignment);
118 buffer
119 }
120
121 #[inline(always)]
123 pub fn alignment(&self) -> Alignment {
124 self.alignment
125 }
126
127 #[inline(always)]
129 pub fn len(&self) -> usize {
130 debug_assert_eq!(self.length, self.bytes.len() / size_of::<T>());
131 self.length
132 }
133
134 #[inline(always)]
136 pub fn is_empty(&self) -> bool {
137 self.length == 0
138 }
139
140 #[inline]
142 pub fn capacity(&self) -> usize {
143 self.bytes.capacity() / size_of::<T>()
144 }
145
146 #[inline]
148 pub fn as_slice(&self) -> &[T] {
149 let raw_slice = self.bytes.as_ref();
150 unsafe { std::slice::from_raw_parts(raw_slice.as_ptr().cast(), self.length) }
152 }
153
154 #[inline]
156 pub fn as_mut_slice(&mut self) -> &mut [T] {
157 let raw_slice = self.bytes.as_mut();
158 unsafe { std::slice::from_raw_parts_mut(raw_slice.as_mut_ptr().cast(), self.length) }
160 }
161
162 #[inline]
164 pub fn clear(&mut self) {
165 unsafe { self.bytes.set_len(0) }
166 self.length = 0;
167 }
168
169 #[inline]
177 pub fn truncate(&mut self, len: usize) {
178 if len <= self.len() {
179 unsafe { self.set_len(len) };
181 }
182 }
183
184 #[inline]
186 pub fn reserve(&mut self, additional: usize) {
187 let additional_bytes = additional * size_of::<T>();
188 if additional_bytes <= self.bytes.capacity() - self.bytes.len() {
189 return;
191 }
192
193 self.reserve_allocate(additional);
195 }
196
197 fn reserve_allocate(&mut self, additional: usize) {
200 let new_capacity: usize = ((self.length + additional) * size_of::<T>()) + *self.alignment;
201 let new_capacity = new_capacity.max(self.bytes.capacity() * 2);
203
204 let mut bytes = BytesMut::with_capacity(new_capacity);
205 bytes.align_empty(self.alignment);
206 bytes.extend_from_slice(&self.bytes);
207 self.bytes = bytes;
208 }
209
210 #[inline]
242 pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] {
243 let dst = self.bytes.spare_capacity_mut().as_mut_ptr();
244 unsafe {
245 std::slice::from_raw_parts_mut(
246 dst as *mut MaybeUninit<T>,
247 self.capacity() - self.length,
248 )
249 }
250 }
251
252 #[inline]
261 pub unsafe fn set_len(&mut self, len: usize) {
262 debug_assert!(len <= self.capacity());
263 unsafe { self.bytes.set_len(len * size_of::<T>()) };
264 self.length = len;
265 }
266
267 #[inline]
269 pub fn push(&mut self, value: T) {
270 self.reserve(1);
271 unsafe { self.push_unchecked(value) }
272 }
273
274 #[inline]
280 pub unsafe fn push_unchecked(&mut self, item: T) {
281 unsafe {
283 let dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
284 dst.write(item);
285 self.bytes.set_len(self.bytes.len() + size_of::<T>())
286 }
287 self.length += 1;
288 }
289
290 #[inline]
294 pub fn push_n(&mut self, item: T, n: usize)
295 where
296 T: Copy,
297 {
298 self.reserve(n);
299 unsafe { self.push_n_unchecked(item, n) }
300 }
301
302 #[inline]
308 pub unsafe fn push_n_unchecked(&mut self, item: T, n: usize)
309 where
310 T: Copy,
311 {
312 let mut dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
313 unsafe {
315 let end = dst.add(n);
316 while dst < end {
317 dst.write(item);
318 dst = dst.add(1);
319 }
320 self.bytes.set_len(self.bytes.len() + (n * size_of::<T>()));
321 }
322 self.length += n;
323 }
324
325 #[inline]
338 pub fn extend_from_slice(&mut self, slice: &[T]) {
339 self.reserve(slice.len());
340 let raw_slice =
341 unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
342 self.bytes.extend_from_slice(raw_slice);
343 self.length += slice.len();
344 }
345
346 pub fn split_off(&mut self, at: usize) -> Self {
356 if at > self.capacity() {
357 vortex_panic!("Cannot split buffer of capacity {} at {}", self.len(), at);
358 }
359
360 let bytes_at = at * size_of::<T>();
361 if !bytes_at.is_multiple_of(*self.alignment) {
362 vortex_panic!(
363 "Cannot split buffer at {}, resulting alignment is not {}",
364 at,
365 self.alignment
366 );
367 }
368
369 let new_bytes = self.bytes.split_off(bytes_at);
370
371 let new_length = self.length.saturating_sub(at);
373 self.length = self.length.min(at);
374
375 BufferMut {
376 bytes: new_bytes,
377 length: new_length,
378 alignment: self.alignment,
379 _marker: Default::default(),
380 }
381 }
382
383 pub fn unsplit(&mut self, other: Self) {
391 if self.alignment != other.alignment {
392 vortex_panic!(
393 "Cannot unsplit buffers with different alignments: {} and {}",
394 self.alignment,
395 other.alignment
396 );
397 }
398 self.bytes.unsplit(other.bytes);
399 self.length += other.length;
400 }
401
402 pub fn into_byte_buffer(self) -> ByteBufferMut {
404 ByteBufferMut {
405 bytes: self.bytes,
406 length: self.length * size_of::<T>(),
407 alignment: self.alignment,
408 _marker: Default::default(),
409 }
410 }
411
412 pub fn freeze(self) -> Buffer<T> {
414 Buffer {
415 bytes: self.bytes.freeze(),
416 length: self.length,
417 alignment: self.alignment,
418 _marker: Default::default(),
419 }
420 }
421
422 pub fn map_each_in_place<R, F>(self, mut f: F) -> BufferMut<R>
424 where
425 T: Copy,
426 F: FnMut(T) -> R,
427 {
428 assert_eq!(
429 size_of::<T>(),
430 size_of::<R>(),
431 "Size of T and R do not match"
432 );
433 let mut buf: BufferMut<R> = unsafe { std::mem::transmute(self) };
435 buf.iter_mut()
436 .for_each(|item| *item = f(unsafe { std::mem::transmute_copy(item) }));
437 buf
438 }
439
440 pub fn aligned(self, alignment: Alignment) -> Self {
446 if self.as_ptr().align_offset(*alignment) == 0 {
447 Self {
448 bytes: self.bytes,
449 length: self.length,
450 alignment,
451 _marker: std::marker::PhantomData,
452 }
453 } else {
454 Self::copy_from_aligned(self, alignment)
455 }
456 }
457
458 pub unsafe fn transmute<U>(self) -> BufferMut<U> {
470 assert_eq!(size_of::<T>(), size_of::<U>(), "Buffer type size mismatch");
471 assert_eq!(
472 align_of::<T>(),
473 align_of::<U>(),
474 "Buffer type alignment mismatch"
475 );
476
477 BufferMut {
478 bytes: self.bytes,
479 length: self.length,
480 alignment: self.alignment,
481 _marker: std::marker::PhantomData,
482 }
483 }
484}
485
486impl<T> Clone for BufferMut<T> {
487 fn clone(&self) -> Self {
488 let mut buffer = BufferMut::<T>::with_capacity_aligned(self.capacity(), self.alignment);
491 buffer.extend_from_slice(self.as_slice());
492 buffer
493 }
494}
495
496impl<T: Debug> Debug for BufferMut<T> {
497 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
498 f.debug_struct(&format!("BufferMut<{}>", type_name::<T>()))
499 .field("length", &self.length)
500 .field("alignment", &self.alignment)
501 .field("as_slice", &TruncatedDebug(self.as_slice()))
502 .finish()
503 }
504}
505
506impl<T> Default for BufferMut<T> {
507 fn default() -> Self {
508 Self::empty()
509 }
510}
511
512impl<T> Deref for BufferMut<T> {
513 type Target = [T];
514
515 #[inline]
516 fn deref(&self) -> &Self::Target {
517 self.as_slice()
518 }
519}
520
521impl<T> DerefMut for BufferMut<T> {
522 #[inline]
523 fn deref_mut(&mut self) -> &mut Self::Target {
524 self.as_mut_slice()
525 }
526}
527
528impl<T> AsRef<[T]> for BufferMut<T> {
529 #[inline]
530 fn as_ref(&self) -> &[T] {
531 self.as_slice()
532 }
533}
534
535impl<T> AsMut<[T]> for BufferMut<T> {
536 #[inline]
537 fn as_mut(&mut self) -> &mut [T] {
538 self.as_mut_slice()
539 }
540}
541
542impl<T> BufferMut<T> {
    /// Appends every item produced by `iter`.
    ///
    /// Items are first written directly into the spare capacity (after reserving the iterator's
    /// lower size bound); anything the iterator still yields once that space is used up is
    /// appended through `push`.
    fn extend_iter(&mut self, mut iter: impl Iterator<Item = T>) {
        // Only the lower bound of the size hint is used to pre-allocate.
        let (lower_bound, _) = iter.size_hint();

        self.reserve(lower_bound);

        // Number of element slots that can be written without reallocating.
        let unwritten = self.capacity() - self.len();

        // `begin` marks the first spare slot; `dst` walks forward from it as items are written.
        let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        let mut dst: *mut T = begin.cast_mut();

        // At most `unwritten` iterations, so `dst` never moves past the element capacity.
        for _ in 0..unwritten {
            let Some(item) = iter.next() else {
                break;
            };

            // SAFETY: fewer than `unwritten` items have been written so far, so `dst` points at
            // a spare slot.
            unsafe { dst.write(item) };

            // SAFETY: advancing by one element lands at most one past the last spare slot.
            unsafe { dst = dst.add(1) };
        }

        // SAFETY: `dst` was derived from `begin` and only ever advanced.
        let items_written = unsafe { dst.offset_from_unsigned(begin) };
        let length = self.len() + items_written;

        // SAFETY: the first `items_written` spare slots were initialized by the loop above.
        unsafe { self.set_len(length) };

        // Remaining items (if any) go through `push`, which reserves as needed.
        iter.for_each(|item| self.push(item));
    }
592
593 pub fn extend_trusted<I: TrustedLen<Item = T>>(&mut self, iter: I) {
598 let (_, upper_bound) = iter.size_hint();
601 self.reserve(
602 upper_bound
603 .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
604 );
605
606 let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
608 let mut dst: *mut T = begin.cast_mut();
609
610 iter.for_each(|item| {
611 unsafe { dst.write(item) };
614
615 unsafe { dst = dst.add(1) };
620 });
621
622 let items_written = unsafe { dst.offset_from_unsigned(begin) };
625 let length = self.len() + items_written;
626
627 unsafe { self.set_len(length) };
629 }
630
631 pub fn from_trusted_len_iter<I>(iter: I) -> Self
635 where
636 I: TrustedLen<Item = T>,
637 {
638 let (_, upper_bound) = iter.size_hint();
639 let mut buffer = Self::with_capacity(
640 upper_bound
641 .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
642 );
643
644 buffer.extend_trusted(iter);
645 buffer
646 }
647}
648
649impl<T> Extend<T> for BufferMut<T> {
650 #[inline]
651 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
652 self.extend_iter(iter.into_iter())
653 }
654}
655
656impl<'a, T> Extend<&'a T> for BufferMut<T>
657where
658 T: Copy + 'a,
659{
660 #[inline]
661 fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
662 self.extend_iter(iter.into_iter().copied())
663 }
664}
665
666impl<T> FromIterator<T> for BufferMut<T> {
667 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
668 let mut buffer = Self::with_capacity(0);
670 buffer.extend(iter);
671 buffer
672 }
673}
674
675impl Buf for ByteBufferMut {
676 fn remaining(&self) -> usize {
677 self.len()
678 }
679
680 fn chunk(&self) -> &[u8] {
681 self.as_slice()
682 }
683
684 fn advance(&mut self, cnt: usize) {
685 if !cnt.is_multiple_of(*self.alignment) {
686 vortex_panic!(
687 "Cannot advance buffer by {} items, resulting alignment is not {}",
688 cnt,
689 self.alignment
690 );
691 }
692 self.bytes.advance(cnt);
693 self.length -= cnt;
694 }
695}
696
697unsafe impl BufMut for ByteBufferMut {
701 #[inline]
702 fn remaining_mut(&self) -> usize {
703 usize::MAX - self.len()
704 }
705
706 #[inline]
707 unsafe fn advance_mut(&mut self, cnt: usize) {
708 if !cnt.is_multiple_of(*self.alignment) {
709 vortex_panic!(
710 "Cannot advance buffer by {} items, resulting alignment is not {}",
711 cnt,
712 self.alignment
713 );
714 }
715 unsafe { self.bytes.advance_mut(cnt) };
716 self.length -= cnt;
717 }
718
719 #[inline]
720 fn chunk_mut(&mut self) -> &mut UninitSlice {
721 self.bytes.chunk_mut()
722 }
723
724 fn put<T: Buf>(&mut self, mut src: T)
725 where
726 Self: Sized,
727 {
728 while src.has_remaining() {
729 let chunk = src.chunk();
730 self.extend_from_slice(chunk);
731 src.advance(chunk.len());
732 }
733 }
734
735 #[inline]
736 fn put_slice(&mut self, src: &[u8]) {
737 self.extend_from_slice(src);
738 }
739
740 #[inline]
741 fn put_bytes(&mut self, val: u8, cnt: usize) {
742 self.push_n(val, cnt)
743 }
744}
745
/// Private extension trait for aligning the start of an empty byte buffer.
trait AlignedBytesMut {
    /// Aligns the start of the buffer to `alignment`.
    fn align_empty(&mut self, alignment: Alignment);
}
755
756impl AlignedBytesMut for BytesMut {
757 fn align_empty(&mut self, alignment: Alignment) {
758 if !self.is_empty() {
760 vortex_panic!("ByteBufferMut must be empty");
761 }
762
763 let padding = self.as_ptr().align_offset(*alignment);
764 self.capacity()
765 .checked_sub(padding)
766 .vortex_expect("Not enough capacity to align buffer");
767
768 unsafe { self.set_len(padding) };
771 self.advance(padding);
772 }
773}
774
775impl Write for ByteBufferMut {
776 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
777 self.extend_from_slice(buf);
778 Ok(buf.len())
779 }
780
781 fn flush(&mut self) -> std::io::Result<()> {
782 Ok(())
783 }
784}
785
#[cfg(test)]
mod test {
    use bytes::Buf;
    use bytes::BufMut;

    use crate::Alignment;
    use crate::BufferMut;
    use crate::ByteBufferMut;
    use crate::buffer_mut;

    /// Capacity never drops below the requested amount while pushing, and alignment is kept.
    #[test]
    fn capacity() {
        const COUNT: usize = 57;
        let alignment = Alignment::new(1024);

        let mut buf = BufferMut::<i32>::with_capacity_aligned(COUNT, alignment);
        assert!(buf.capacity() >= COUNT);

        for remaining in (1..=COUNT).rev() {
            buf.push(0);
            assert!(buf.capacity() >= remaining);
        }

        assert_eq!(buf.alignment(), alignment);
    }

    /// Collecting an iterator keeps the items in order.
    #[test]
    fn from_iter() {
        let values = [0, 10, 20, 30];
        let buf = BufferMut::from_iter(values);
        assert_eq!(buf.as_slice(), &values);
    }

    /// Successive `extend` calls append to the end.
    #[test]
    fn extend() {
        let mut buf = BufferMut::empty();
        buf.extend([0i32, 10, 20, 30]);
        buf.extend([40, 50, 60]);

        let expected = [0, 10, 20, 30, 40, 50, 60];
        assert_eq!(buf.as_slice(), &expected);
    }

    /// `push` appends one element at a time.
    #[test]
    fn push() {
        let mut buf = BufferMut::empty();
        for value in [1, 2, 3] {
            buf.push(value);
        }
        assert_eq!(buf.as_slice(), &[1, 2, 3]);
    }

    /// `push_n` appends the requested number of copies.
    #[test]
    fn push_n() {
        const COPIES: usize = 100;
        let mut buf = BufferMut::empty();
        buf.push_n(0, COPIES);
        assert_eq!(buf.as_slice(), &[0; COPIES]);
    }

    /// Elements can be modified through indexing and through `as_mut`.
    #[test]
    fn as_mut() {
        let mut buf = buffer_mut![0, 1, 2];
        buf[1] = 0;
        buf.as_mut()[2] = 0;
        assert_eq!(buf.as_slice(), &[0, 0, 0]);
    }

    /// `map_each_in_place` can change the element type when the sizes match.
    #[test]
    fn map_each() {
        let source = buffer_mut![0i32, 1, 2];
        let mapped = source.map_each_in_place(|i| (i + 1) as u32);
        assert_eq!(mapped.as_slice(), &[1u32, 2, 3]);
    }

    /// The `Buf` implementation consumes bytes from the front.
    #[test]
    fn bytes_buf() {
        let mut buf = ByteBufferMut::copy_from("helloworld".as_bytes());
        assert_eq!(buf.remaining(), 10);
        assert_eq!(buf.chunk(), b"helloworld");

        buf.advance(5);

        assert_eq!(buf.remaining(), 5);
        assert_eq!(buf.as_slice(), b"world");
        assert_eq!(buf.chunk(), b"world");
    }

    /// The `BufMut` implementation appends bytes at the end.
    #[test]
    fn bytes_buf_mut() {
        let mut buf = ByteBufferMut::copy_from("hello".as_bytes());
        assert_eq!(BufMut::remaining_mut(&buf), usize::MAX - 5);

        buf.put_slice(b"world");

        assert_eq!(buf.as_slice(), b"helloworld");
    }

    /// `zeroed` yields an aligned, all-zero, writable buffer.
    #[test]
    fn buffer_mut_zeroed() {
        const LEN: usize = 17;

        let mut buf = BufferMut::<u32>::zeroed(LEN);

        let natural = Alignment::of::<u32>();
        assert_eq!(buf.as_ptr().align_offset(*natural), 0);
        assert_eq!(buf.as_slice(), &[0; LEN]);

        buf[3] = 7;
        assert_eq!(buf.as_slice()[3], 7);
    }

    /// `zeroed_aligned` honors a custom alignment.
    #[test]
    fn buffer_mut_zeroed_aligned() {
        const LEN: usize = 17;
        let alignment = Alignment::new(64);

        let mut buf = BufferMut::<u32>::zeroed_aligned(LEN, alignment);

        assert_eq!(buf.as_ptr().align_offset(*alignment), 0);
        assert_eq!(buf.as_slice(), &[0; LEN]);

        buf[3] = 7;
        assert_eq!(buf.as_slice()[3], 7);
    }
}