1use core::mem::MaybeUninit;
5use std::any::type_name;
6use std::fmt::Debug;
7use std::fmt::Formatter;
8use std::io::Write;
9use std::ops::Deref;
10use std::ops::DerefMut;
11
12use bytes::Buf;
13use bytes::BufMut;
14use bytes::BytesMut;
15use bytes::buf::UninitSlice;
16use vortex_error::VortexExpect;
17use vortex_error::vortex_panic;
18
19use crate::Alignment;
20use crate::Buffer;
21use crate::ByteBufferMut;
22use crate::debug::TruncatedDebug;
23use crate::trusted_len::TrustedLen;
24
#[derive(PartialEq, Eq)]
pub struct BufferMut<T> {
    /// Underlying mutable byte storage; its start pointer is kept aligned to `alignment`
    /// (see `AlignedBytesMut::align_empty`).
    pub(crate) bytes: BytesMut,
    /// Number of initialized `T` elements; invariant: `bytes.len() / size_of::<T>()`.
    pub(crate) length: usize,
    /// Alignment guaranteed for the first byte of the buffer.
    pub(crate) alignment: Alignment,
    /// Ties the buffer to its element type `T` without storing a `T`.
    pub(crate) _marker: std::marker::PhantomData<T>,
}
33
34impl<T> BufferMut<T> {
35 pub fn with_capacity(capacity: usize) -> Self {
37 Self::with_capacity_aligned(capacity, Alignment::of::<T>())
38 }
39
40 pub fn with_capacity_aligned(capacity: usize, alignment: Alignment) -> Self {
42 if !alignment.is_aligned_to(Alignment::of::<T>()) {
43 vortex_panic!(
44 "Alignment {} must align to the scalar type's alignment {}",
45 alignment,
46 align_of::<T>()
47 );
48 }
49
50 let mut bytes = BytesMut::with_capacity((capacity * size_of::<T>()) + *alignment);
51 bytes.align_empty(alignment);
52
53 Self {
54 bytes,
55 length: 0,
56 alignment,
57 _marker: Default::default(),
58 }
59 }
60
61 pub fn zeroed(len: usize) -> Self {
63 Self::zeroed_aligned(len, Alignment::of::<T>())
64 }
65
66 pub fn zeroed_aligned(len: usize, alignment: Alignment) -> Self {
68 let mut bytes = BytesMut::zeroed((len * size_of::<T>()) + *alignment);
69 bytes.advance(bytes.as_ptr().align_offset(*alignment));
70 unsafe { bytes.set_len(len * size_of::<T>()) };
71 Self {
72 bytes,
73 length: len,
74 alignment,
75 _marker: Default::default(),
76 }
77 }
78
79 pub fn empty() -> Self {
81 Self::empty_aligned(Alignment::of::<T>())
82 }
83
84 pub fn empty_aligned(alignment: Alignment) -> Self {
86 BufferMut::with_capacity_aligned(0, alignment)
87 }
88
89 pub fn full(item: T, len: usize) -> Self
91 where
92 T: Copy,
93 {
94 let mut buffer = BufferMut::<T>::with_capacity(len);
95 buffer.push_n(item, len);
96 buffer
97 }
98
99 pub fn copy_from(other: impl AsRef<[T]>) -> Self {
101 Self::copy_from_aligned(other, Alignment::of::<T>())
102 }
103
104 pub fn copy_from_aligned(other: impl AsRef<[T]>, alignment: Alignment) -> Self {
110 if !alignment.is_aligned_to(Alignment::of::<T>()) {
111 vortex_panic!("Given alignment is not aligned to type T")
112 }
113 let other = other.as_ref();
114 let mut buffer = Self::with_capacity_aligned(other.len(), alignment);
115 buffer.extend_from_slice(other);
116 debug_assert_eq!(buffer.alignment(), alignment);
117 buffer
118 }
119
120 #[inline(always)]
122 pub fn alignment(&self) -> Alignment {
123 self.alignment
124 }
125
126 #[inline(always)]
128 pub fn len(&self) -> usize {
129 debug_assert_eq!(self.length, self.bytes.len() / size_of::<T>());
130 self.length
131 }
132
133 #[inline(always)]
135 pub fn is_empty(&self) -> bool {
136 self.length == 0
137 }
138
139 #[inline]
141 pub fn capacity(&self) -> usize {
142 self.bytes.capacity() / size_of::<T>()
143 }
144
145 #[inline]
147 pub fn as_slice(&self) -> &[T] {
148 let raw_slice = self.bytes.as_ref();
149 unsafe { std::slice::from_raw_parts(raw_slice.as_ptr().cast(), self.length) }
151 }
152
153 #[inline]
155 pub fn as_mut_slice(&mut self) -> &mut [T] {
156 let raw_slice = self.bytes.as_mut();
157 unsafe { std::slice::from_raw_parts_mut(raw_slice.as_mut_ptr().cast(), self.length) }
159 }
160
161 #[inline]
163 pub fn clear(&mut self) {
164 unsafe { self.bytes.set_len(0) }
165 self.length = 0;
166 }
167
168 #[inline]
176 pub fn truncate(&mut self, len: usize) {
177 if len <= self.len() {
178 unsafe { self.set_len(len) };
180 }
181 }
182
183 #[inline]
185 pub fn reserve(&mut self, additional: usize) {
186 let additional_bytes = additional * size_of::<T>();
187 if additional_bytes <= self.bytes.capacity() - self.bytes.len() {
188 return;
190 }
191
192 self.reserve_allocate(additional);
194 }
195
196 fn reserve_allocate(&mut self, additional: usize) {
199 let new_capacity: usize = ((self.length + additional) * size_of::<T>()) + *self.alignment;
200 let new_capacity = new_capacity.max(self.bytes.capacity() * 2);
202
203 let mut bytes = BytesMut::with_capacity(new_capacity);
204 bytes.align_empty(self.alignment);
205 bytes.extend_from_slice(&self.bytes);
206 self.bytes = bytes;
207 }
208
209 #[inline]
241 pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] {
242 let dst = self.bytes.spare_capacity_mut().as_mut_ptr();
243 unsafe {
244 std::slice::from_raw_parts_mut(
245 dst as *mut MaybeUninit<T>,
246 self.capacity() - self.length,
247 )
248 }
249 }
250
251 #[inline]
260 pub unsafe fn set_len(&mut self, len: usize) {
261 debug_assert!(len <= self.capacity());
262 unsafe { self.bytes.set_len(len * size_of::<T>()) };
263 self.length = len;
264 }
265
266 #[inline]
268 pub fn push(&mut self, value: T) {
269 self.reserve(1);
270 unsafe { self.push_unchecked(value) }
271 }
272
273 #[inline]
279 pub unsafe fn push_unchecked(&mut self, item: T) {
280 unsafe {
282 let dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
283 dst.write(item);
284 self.bytes.set_len(self.bytes.len() + size_of::<T>())
285 }
286 self.length += 1;
287 }
288
289 #[inline]
293 pub fn push_n(&mut self, item: T, n: usize)
294 where
295 T: Copy,
296 {
297 self.reserve(n);
298 unsafe { self.push_n_unchecked(item, n) }
299 }
300
301 #[inline]
307 pub unsafe fn push_n_unchecked(&mut self, item: T, n: usize)
308 where
309 T: Copy,
310 {
311 let mut dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
312 unsafe {
314 let end = dst.add(n);
315 while dst < end {
316 dst.write(item);
317 dst = dst.add(1);
318 }
319 self.bytes.set_len(self.bytes.len() + (n * size_of::<T>()));
320 }
321 self.length += n;
322 }
323
324 #[inline]
337 pub fn extend_from_slice(&mut self, slice: &[T]) {
338 self.reserve(slice.len());
339 let raw_slice =
340 unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
341 self.bytes.extend_from_slice(raw_slice);
342 self.length += slice.len();
343 }
344
345 pub fn split_off(&mut self, at: usize) -> Self {
355 if at > self.capacity() {
356 vortex_panic!("Cannot split buffer of capacity {} at {}", self.len(), at);
357 }
358
359 let bytes_at = at * size_of::<T>();
360 if !bytes_at.is_multiple_of(*self.alignment) {
361 vortex_panic!(
362 "Cannot split buffer at {}, resulting alignment is not {}",
363 at,
364 self.alignment
365 );
366 }
367
368 let new_bytes = self.bytes.split_off(bytes_at);
369
370 let new_length = self.length.saturating_sub(at);
372 self.length = self.length.min(at);
373
374 BufferMut {
375 bytes: new_bytes,
376 length: new_length,
377 alignment: self.alignment,
378 _marker: Default::default(),
379 }
380 }
381
382 pub fn unsplit(&mut self, other: Self) {
390 if self.alignment != other.alignment {
391 vortex_panic!(
392 "Cannot unsplit buffers with different alignments: {} and {}",
393 self.alignment,
394 other.alignment
395 );
396 }
397 self.bytes.unsplit(other.bytes);
398 self.length += other.length;
399 }
400
401 pub fn into_byte_buffer(self) -> ByteBufferMut {
403 ByteBufferMut {
404 bytes: self.bytes,
405 length: self.length * size_of::<T>(),
406 alignment: self.alignment,
407 _marker: Default::default(),
408 }
409 }
410
411 pub fn freeze(self) -> Buffer<T> {
413 Buffer {
414 bytes: self.bytes.freeze(),
415 length: self.length,
416 alignment: self.alignment,
417 _marker: Default::default(),
418 }
419 }
420
421 pub fn map_each_in_place<R, F>(self, mut f: F) -> BufferMut<R>
423 where
424 T: Copy,
425 F: FnMut(T) -> R,
426 {
427 assert_eq!(
428 size_of::<T>(),
429 size_of::<R>(),
430 "Size of T and R do not match"
431 );
432 let mut buf: BufferMut<R> = unsafe { std::mem::transmute(self) };
434 buf.iter_mut()
435 .for_each(|item| *item = f(unsafe { std::mem::transmute_copy(item) }));
436 buf
437 }
438
439 pub fn aligned(self, alignment: Alignment) -> Self {
445 if self.as_ptr().align_offset(*alignment) == 0 {
446 Self {
447 bytes: self.bytes,
448 length: self.length,
449 alignment,
450 _marker: std::marker::PhantomData,
451 }
452 } else {
453 Self::copy_from_aligned(self, alignment)
454 }
455 }
456
457 pub unsafe fn transmute<U>(self) -> BufferMut<U> {
469 assert_eq!(size_of::<T>(), size_of::<U>(), "Buffer type size mismatch");
470 assert_eq!(
471 align_of::<T>(),
472 align_of::<U>(),
473 "Buffer type alignment mismatch"
474 );
475
476 BufferMut {
477 bytes: self.bytes,
478 length: self.length,
479 alignment: self.alignment,
480 _marker: std::marker::PhantomData,
481 }
482 }
483}
484
485impl<T> Clone for BufferMut<T> {
486 fn clone(&self) -> Self {
487 let mut buffer = BufferMut::<T>::with_capacity_aligned(self.capacity(), self.alignment);
490 buffer.extend_from_slice(self.as_slice());
491 buffer
492 }
493}
494
495impl<T: Debug> Debug for BufferMut<T> {
496 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
497 f.debug_struct(&format!("BufferMut<{}>", type_name::<T>()))
498 .field("length", &self.length)
499 .field("alignment", &self.alignment)
500 .field("as_slice", &TruncatedDebug(self.as_slice()))
501 .finish()
502 }
503}
504
505impl<T> Default for BufferMut<T> {
506 fn default() -> Self {
507 Self::empty()
508 }
509}
510
511impl<T> Deref for BufferMut<T> {
512 type Target = [T];
513
514 #[inline]
515 fn deref(&self) -> &Self::Target {
516 self.as_slice()
517 }
518}
519
520impl<T> DerefMut for BufferMut<T> {
521 #[inline]
522 fn deref_mut(&mut self) -> &mut Self::Target {
523 self.as_mut_slice()
524 }
525}
526
527impl<T> AsRef<[T]> for BufferMut<T> {
528 #[inline]
529 fn as_ref(&self) -> &[T] {
530 self.as_slice()
531 }
532}
533
534impl<T> AsMut<[T]> for BufferMut<T> {
535 #[inline]
536 fn as_mut(&mut self) -> &mut [T] {
537 self.as_mut_slice()
538 }
539}
540
impl<T> BufferMut<T> {
    /// Extend the buffer from an arbitrary iterator.
    ///
    /// Reserves from the iterator's lower size hint, then writes directly into spare
    /// capacity through a raw pointer; any items the hint under-reported are appended
    /// one at a time via `push` (which re-reserves as needed).
    fn extend_iter(&mut self, mut iter: impl Iterator<Item = T>) {
        let (lower_bound, _) = iter.size_hint();

        self.reserve(lower_bound);

        // Number of elements the spare capacity can hold without reallocating.
        let unwritten = self.capacity() - self.len();

        let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        let mut dst: *mut T = begin.cast_mut();

        for _ in 0..unwritten {
            let Some(item) = iter.next() else {
                break;
            };

            // SAFETY: the loop writes at most `unwritten` elements, so `dst` stays
            // within the spare capacity region starting at `begin`.
            unsafe { dst.write(item) };

            unsafe { dst = dst.add(1) };
        }

        // SAFETY: `begin` and `dst` point into the same allocation with `dst >= begin`.
        let items_written = unsafe { dst.offset_from_unsigned(begin) };
        let length = self.len() + items_written;

        // SAFETY: exactly `items_written` elements past the old length were initialized
        // above, and `length` cannot exceed capacity.
        unsafe { self.set_len(length) };

        // Drain whatever the lower bound did not account for.
        iter.for_each(|item| self.push(item));
    }

    /// Extend the buffer from a [`TrustedLen`] iterator.
    ///
    /// Because the upper size hint is trusted to be exact, all items are written
    /// straight into spare capacity with no per-item capacity checks.
    pub fn extend_trusted<I: TrustedLen<Item = T>>(&mut self, iter: I) {
        let (_, upper_bound) = iter.size_hint();
        self.reserve(
            upper_bound
                .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
        );

        let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        let mut dst: *mut T = begin.cast_mut();

        iter.for_each(|item| {
            // SAFETY: `TrustedLen` guarantees the iterator yields at most `upper_bound`
            // items, for which capacity was reserved above.
            unsafe { dst.write(item) };

            unsafe { dst = dst.add(1) };
        });

        // SAFETY: `begin` and `dst` point into the same allocation with `dst >= begin`.
        let items_written = unsafe { dst.offset_from_unsigned(begin) };
        let length = self.len() + items_written;

        // SAFETY: all `items_written` new elements were initialized above.
        unsafe { self.set_len(length) };
    }

    /// Build a new buffer from a [`TrustedLen`] iterator, pre-allocating exactly the
    /// iterator's (trusted) upper bound.
    pub fn from_trusted_len_iter<I>(iter: I) -> Self
    where
        I: TrustedLen<Item = T>,
    {
        let (_, upper_bound) = iter.size_hint();
        let mut buffer = Self::with_capacity(
            upper_bound
                .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
        );

        buffer.extend_trusted(iter);
        buffer
    }
}
647
648impl<T> Extend<T> for BufferMut<T> {
649 #[inline]
650 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
651 self.extend_iter(iter.into_iter())
652 }
653}
654
655impl<'a, T> Extend<&'a T> for BufferMut<T>
656where
657 T: Copy + 'a,
658{
659 #[inline]
660 fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
661 self.extend_iter(iter.into_iter().copied())
662 }
663}
664
665impl<T> FromIterator<T> for BufferMut<T> {
666 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
667 let mut buffer = Self::with_capacity(0);
669 buffer.extend(iter);
670 buffer
671 }
672}
673
674impl Buf for ByteBufferMut {
675 fn remaining(&self) -> usize {
676 self.len()
677 }
678
679 fn chunk(&self) -> &[u8] {
680 self.as_slice()
681 }
682
683 fn advance(&mut self, cnt: usize) {
684 if !cnt.is_multiple_of(*self.alignment) {
685 vortex_panic!(
686 "Cannot advance buffer by {} items, resulting alignment is not {}",
687 cnt,
688 self.alignment
689 );
690 }
691 self.bytes.advance(cnt);
692 self.length -= cnt;
693 }
694}
695
696unsafe impl BufMut for ByteBufferMut {
700 #[inline]
701 fn remaining_mut(&self) -> usize {
702 usize::MAX - self.len()
703 }
704
705 #[inline]
706 unsafe fn advance_mut(&mut self, cnt: usize) {
707 if !cnt.is_multiple_of(*self.alignment) {
708 vortex_panic!(
709 "Cannot advance buffer by {} items, resulting alignment is not {}",
710 cnt,
711 self.alignment
712 );
713 }
714 unsafe { self.bytes.advance_mut(cnt) };
715 self.length -= cnt;
716 }
717
718 #[inline]
719 fn chunk_mut(&mut self) -> &mut UninitSlice {
720 self.bytes.chunk_mut()
721 }
722
723 fn put<T: Buf>(&mut self, mut src: T)
724 where
725 Self: Sized,
726 {
727 while src.has_remaining() {
728 let chunk = src.chunk();
729 self.extend_from_slice(chunk);
730 src.advance(chunk.len());
731 }
732 }
733
734 #[inline]
735 fn put_slice(&mut self, src: &[u8]) {
736 self.extend_from_slice(src);
737 }
738
739 #[inline]
740 fn put_bytes(&mut self, val: u8, cnt: usize) {
741 self.push_n(val, cnt)
742 }
743}
744
/// Internal helper for aligning the start of an empty `BytesMut`.
trait AlignedBytesMut {
    /// Advance the (empty) buffer's start pointer to the next `alignment` boundary,
    /// consuming part of its capacity as padding.
    fn align_empty(&mut self, alignment: Alignment);
}
754
impl AlignedBytesMut for BytesMut {
    fn align_empty(&mut self, alignment: Alignment) {
        // Aligning is only sound before any data has been written.
        if !self.is_empty() {
            vortex_panic!("ByteBufferMut must be empty");
        }

        // Padding bytes needed to move the start pointer to the next aligned address.
        let padding = self.as_ptr().align_offset(*alignment);
        self.capacity()
            .checked_sub(padding)
            .vortex_expect("Not enough capacity to align buffer");

        // `advance` can only skip over bytes within `len`, so first mark the padding as
        // part of the length, then advance past it. The padding bytes are never read.
        // SAFETY: `padding` is within capacity (checked above).
        unsafe { self.set_len(padding) };
        self.advance(padding);
    }
}
773
774impl Write for ByteBufferMut {
775 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
776 self.extend_from_slice(buf);
777 Ok(buf.len())
778 }
779
780 fn flush(&mut self) -> std::io::Result<()> {
781 Ok(())
782 }
783}
784
#[cfg(test)]
mod test {
    use bytes::Buf;
    use bytes::BufMut;

    use crate::Alignment;
    use crate::BufferMut;
    use crate::ByteBufferMut;
    use crate::buffer_mut;

    /// Capacity never drops below what is needed during growth, and a custom alignment
    /// survives pushes.
    #[test]
    fn capacity() {
        let mut n = 57;
        let mut buf = BufferMut::<i32>::with_capacity_aligned(n, Alignment::new(1024));
        assert!(buf.capacity() >= 57);

        while n > 0 {
            buf.push(0);
            assert!(buf.capacity() >= n);
            n -= 1
        }

        assert_eq!(buf.alignment(), Alignment::new(1024));
    }

    /// `FromIterator` preserves element order.
    #[test]
    fn from_iter() {
        let buf = BufferMut::from_iter([0, 10, 20, 30]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30]);
    }

    /// Repeated `extend` calls append in order.
    #[test]
    fn extend() {
        let mut buf = BufferMut::empty();
        buf.extend([0i32, 10, 20, 30]);
        buf.extend([40, 50, 60]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30, 40, 50, 60]);
    }

    /// `push` appends single elements in order.
    #[test]
    fn push() {
        let mut buf = BufferMut::empty();
        buf.push(1);
        buf.push(2);
        buf.push(3);
        assert_eq!(buf.as_slice(), &[1, 2, 3]);
    }

    /// `push_n` appends the requested number of copies.
    #[test]
    fn push_n() {
        let mut buf = BufferMut::empty();
        buf.push_n(0, 100);
        assert_eq!(buf.as_slice(), &[0; 100]);
    }

    /// Both `Deref`-based indexing and `as_mut()` expose writable elements.
    #[test]
    fn as_mut() {
        let mut buf = buffer_mut![0, 1, 2];
        buf[1] = 0;
        buf.as_mut()[2] = 0;
        assert_eq!(buf.as_slice(), &[0, 0, 0]);
    }

    /// `map_each_in_place` can change the element type (same size) while reusing the
    /// allocation.
    #[test]
    fn map_each() {
        let buf = buffer_mut![0i32, 1, 2];
        let buf = buf.map_each_in_place(|i| (i + 1) as u32);
        assert_eq!(buf.as_slice(), &[1u32, 2, 3]);
    }

    /// The `Buf` impl exposes the buffer as one contiguous chunk and supports
    /// front-consumption via `advance`.
    #[test]
    fn bytes_buf() {
        let mut buf = ByteBufferMut::copy_from("helloworld".as_bytes());
        assert_eq!(buf.remaining(), 10);
        assert_eq!(buf.chunk(), b"helloworld");

        Buf::advance(&mut buf, 5);
        assert_eq!(buf.remaining(), 5);
        assert_eq!(buf.as_slice(), b"world");
        assert_eq!(buf.chunk(), b"world");
    }

    /// The `BufMut` impl reports near-unbounded headroom and appends via `put_slice`.
    #[test]
    fn bytes_buf_mut() {
        let mut buf = ByteBufferMut::copy_from("hello".as_bytes());
        assert_eq!(BufMut::remaining_mut(&buf), usize::MAX - 5);

        BufMut::put_slice(&mut buf, b"world");
        assert_eq!(buf.as_slice(), b"helloworld");
    }
}