use core::mem::MaybeUninit;
use std::any::type_name;
use std::fmt::{Debug, Formatter};
use std::ops::{Deref, DerefMut};

use bytes::buf::UninitSlice;
use bytes::{Buf, BufMut, BytesMut};
use vortex_error::{VortexExpect, vortex_panic};

use crate::debug::TruncatedDebug;
use crate::spec_extend::SpecExtend;
use crate::{Alignment, Buffer, ByteBufferMut};

#[derive(PartialEq, Eq)]
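/// A mutable, growable buffer of `T` values backed by a [`BytesMut`], carrying an explicit
/// [`Alignment`] that the start of the buffer is guaranteed to satisfy.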
pub struct BufferMut<T> {
    pub(crate) bytes: BytesMut,
    pub(crate) length: usize,
    pub(crate) alignment: Alignment,
    pub(crate) _marker: std::marker::PhantomData<T>,
}

impl<T> BufferMut<T> {
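    /// Create an empty buffer with capacity for at least `capacity` elements of `T`,
    /// aligned to the natural alignment of `T`.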
    pub fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity_aligned(capacity, Alignment::of::<T>())
    }

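    /// Create an empty buffer with capacity for at least `capacity` elements of `T`,
    /// aligned to the given `alignment`.
    ///
    /// # Panics
    ///
    /// Panics if `alignment` is not itself aligned to the alignment of `T`.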
    pub fn with_capacity_aligned(capacity: usize, alignment: Alignment) -> Self {
        if !alignment.is_aligned_to(Alignment::of::<T>()) {
            vortex_panic!(
                "Alignment {} must align to the scalar type's alignment {}",
                alignment,
                align_of::<T>()
            );
        }

        let mut bytes = BytesMut::with_capacity((capacity * size_of::<T>()) + *alignment);
        bytes.align_empty(alignment);

        Self {
            bytes,
            length: 0,
            alignment,
            _marker: Default::default(),
        }
    }

    pub fn zeroed(len: usize) -> Self {
        Self::zeroed_aligned(len, Alignment::of::<T>())
    }

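    /// Create a buffer of `len` elements whose backing bytes are all zero, aligned to the
    /// given `alignment`.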
55 pub fn zeroed_aligned(len: usize, alignment: Alignment) -> Self {
57 let mut bytes = BytesMut::zeroed((len * size_of::<T>()) + *alignment);
58 bytes.advance(bytes.as_ptr().align_offset(*alignment));
59 unsafe { bytes.set_len(len * size_of::<T>()) };
60 Self {
61 bytes,
62 length: len,
63 alignment,
64 _marker: Default::default(),
65 }
66 }
67
68 pub fn empty() -> Self {
70 Self::empty_aligned(Alignment::of::<T>())
71 }
72
73 pub fn empty_aligned(alignment: Alignment) -> Self {
75 BufferMut::with_capacity_aligned(0, alignment)
76 }
77
78 pub fn full(item: T, len: usize) -> Self
80 where
81 T: Copy,
82 {
83 let mut buffer = BufferMut::<T>::with_capacity(len);
84 buffer.push_n(item, len);
85 buffer
86 }
87
88 pub fn copy_from(other: impl AsRef<[T]>) -> Self {
90 Self::copy_from_aligned(other, Alignment::of::<T>())
91 }
92
93 pub fn copy_from_aligned(other: impl AsRef<[T]>, alignment: Alignment) -> Self {
99 if !alignment.is_aligned_to(Alignment::of::<T>()) {
100 vortex_panic!("Given alignment is not aligned to type T")
101 }
102 let other = other.as_ref();
103 let mut buffer = Self::with_capacity_aligned(other.len(), alignment);
104 buffer.extend_from_slice(other);
105 debug_assert_eq!(buffer.alignment(), alignment);
106 buffer
107 }
108
109 #[inline(always)]
111 pub fn alignment(&self) -> Alignment {
112 self.alignment
113 }
114
115 #[inline(always)]
117 pub fn len(&self) -> usize {
118 debug_assert_eq!(self.length, self.bytes.len() / size_of::<T>());
119 self.length
120 }
121
122 #[inline(always)]
124 pub fn is_empty(&self) -> bool {
125 self.length == 0
126 }
127
128 #[inline]
130 pub fn capacity(&self) -> usize {
131 self.bytes.capacity() / size_of::<T>()
132 }
133
134 #[inline]
136 pub fn as_slice(&self) -> &[T] {
137 let raw_slice = self.bytes.as_ref();
138 unsafe { std::slice::from_raw_parts(raw_slice.as_ptr().cast(), self.length) }
140 }
141
142 #[inline]
144 pub fn as_mut_slice(&mut self) -> &mut [T] {
145 let raw_slice = self.bytes.as_mut();
146 unsafe { std::slice::from_raw_parts_mut(raw_slice.as_mut_ptr().cast(), self.length) }
148 }
149
150 #[inline]
152 pub fn clear(&mut self) {
153 unsafe { self.bytes.set_len(0) }
154 self.length = 0;
155 }
156
157 #[inline]
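    /// Shorten the buffer to `len` elements, keeping the front of the buffer.
    ///
    /// Has no effect if `len` is greater than the current length.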
    pub fn truncate(&mut self, len: usize) {
        if len <= self.len() {
            unsafe { self.set_len(len) };
        }
    }

    #[inline]
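    /// Reserve capacity for at least `additional` more elements of `T`, reallocating and
    /// re-aligning the underlying bytes if necessary.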
    pub fn reserve(&mut self, additional: usize) {
        let additional_bytes = additional * size_of::<T>();
        if additional_bytes <= self.bytes.capacity() - self.bytes.len() {
            return;
        }

        self.reserve_allocate(additional);
    }

    fn reserve_allocate(&mut self, additional: usize) {
        let new_capacity: usize = ((self.length + additional) * size_of::<T>()) + *self.alignment;
        let new_capacity = new_capacity.max(self.bytes.capacity() * 2);

        let mut bytes = BytesMut::with_capacity(new_capacity);
        bytes.align_empty(self.alignment);
        bytes.extend_from_slice(&self.bytes);
        self.bytes = bytes;
    }

    #[inline]
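    /// Return the remaining spare capacity of the buffer as a slice of `MaybeUninit<T>`.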
    pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] {
        let dst = self.bytes.spare_capacity_mut().as_mut_ptr();
        unsafe {
            std::slice::from_raw_parts_mut(
                dst as *mut MaybeUninit<T>,
                self.capacity() - self.length,
            )
        }
    }

    #[inline]
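    /// Set the length of the buffer to `len` elements.
    ///
    /// # Safety
    ///
    /// `len` must not exceed the buffer's capacity, and the first `len` elements must be
    /// initialized.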
    pub unsafe fn set_len(&mut self, len: usize) {
        unsafe { self.bytes.set_len(len * size_of::<T>()) };
        self.length = len;
    }

    #[inline]
    pub fn push(&mut self, value: T) {
        self.reserve(1);
        unsafe { self.push_unchecked(value) }
    }

    #[inline]
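    /// Append an element to the buffer without checking capacity.
    ///
    /// # Safety
    ///
    /// The caller must ensure there is spare capacity for at least one more element.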
    pub unsafe fn push_unchecked(&mut self, item: T) {
        unsafe {
            let dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
            dst.write(item);
            self.bytes.set_len(self.bytes.len() + size_of::<T>())
        }
        self.length += 1;
    }

    #[inline]
    pub fn push_n(&mut self, item: T, n: usize)
    where
        T: Copy,
    {
        self.reserve(n);
        unsafe { self.push_n_unchecked(item, n) }
    }

    #[inline]
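    /// Append `n` copies of `item` to the buffer without checking capacity.
    ///
    /// # Safety
    ///
    /// The caller must ensure there is spare capacity for at least `n` more elements.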
    pub unsafe fn push_n_unchecked(&mut self, item: T, n: usize)
    where
        T: Copy,
    {
        let mut dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        unsafe {
            let end = dst.add(n);
            while dst < end {
                dst.write(item);
                dst = dst.add(1);
            }
            self.bytes.set_len(self.bytes.len() + (n * size_of::<T>()));
        }
        self.length += n;
    }

    #[inline]
    pub fn extend_from_slice(&mut self, slice: &[T]) {
        self.reserve(slice.len());
        let raw_slice: &[u8] =
            unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
        self.bytes.extend_from_slice(raw_slice);
        self.length += slice.len();
    }

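    /// Freeze the buffer into an immutable [`Buffer`], preserving its length and alignment.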
    pub fn freeze(self) -> Buffer<T> {
        Buffer {
            bytes: self.bytes.freeze(),
            length: self.length,
            alignment: self.alignment,
            _marker: Default::default(),
        }
    }

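    /// Map each element of the buffer through `f`, reusing the existing allocation.
    ///
    /// # Panics
    ///
    /// Panics if `T` and `R` do not have the same size.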
    pub fn map_each<R, F>(self, mut f: F) -> BufferMut<R>
    where
        T: Copy,
        F: FnMut(T) -> R,
    {
        assert_eq!(
            size_of::<T>(),
            size_of::<R>(),
            "Size of T and R do not match"
        );
        let mut buf: BufferMut<R> = unsafe { std::mem::transmute(self) };
        buf.iter_mut()
            .for_each(|item| *item = f(unsafe { std::mem::transmute_copy(item) }));
        buf
    }

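    /// Return this buffer with the requested `alignment`, copying the data into a freshly
    /// aligned allocation only if the current start pointer is not already aligned.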
    pub fn aligned(self, alignment: Alignment) -> Self {
        if self.as_ptr().align_offset(*alignment) == 0 {
            self
        } else {
            Self::copy_from_aligned(self, alignment)
        }
    }
}

impl<T> Clone for BufferMut<T> {
    fn clone(&self) -> Self {
        let mut buffer = BufferMut::<T>::with_capacity_aligned(self.capacity(), self.alignment);
        buffer.extend_from_slice(self.as_slice());
        buffer
    }
}

impl<T: Debug> Debug for BufferMut<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct(&format!("BufferMut<{}>", type_name::<T>()))
            .field("length", &self.length)
            .field("alignment", &self.alignment)
            .field("as_slice", &TruncatedDebug(self.as_slice()))
            .finish()
    }
}

impl<T> Default for BufferMut<T> {
    fn default() -> Self {
        Self::empty()
    }
}

impl<T> Deref for BufferMut<T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<T> DerefMut for BufferMut<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

impl<T> AsRef<[T]> for BufferMut<T> {
    fn as_ref(&self) -> &[T] {
        self.as_slice()
    }
}

impl<T> AsMut<[T]> for BufferMut<T> {
    fn as_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}

impl<T> Extend<T> for BufferMut<T> {
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        <Self as SpecExtend<T, I::IntoIter>>::spec_extend(self, iter.into_iter())
    }
}

impl<'a, T> Extend<&'a T> for BufferMut<T>
where
    T: Copy + 'a,
{
    #[inline]
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        self.spec_extend(iter.into_iter())
    }
}

impl<T> FromIterator<T> for BufferMut<T> {
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut buffer = Self::with_capacity(0);
        buffer.extend(iter);
        debug_assert_eq!(buffer.alignment(), Alignment::of::<T>());
        buffer
    }
}

impl Buf for ByteBufferMut {
    fn remaining(&self) -> usize {
        self.len()
    }

    fn chunk(&self) -> &[u8] {
        self.as_slice()
    }

    fn advance(&mut self, cnt: usize) {
        if !cnt.is_multiple_of(*self.alignment) {
            vortex_panic!(
                "Cannot advance buffer by {} bytes, the buffer would no longer be aligned to {}",
                cnt,
                self.alignment
            );
        }
        self.bytes.advance(cnt);
        self.length -= cnt;
    }
}

unsafe impl BufMut for ByteBufferMut {
    #[inline]
    fn remaining_mut(&self) -> usize {
        usize::MAX - self.len()
    }

    #[inline]
    unsafe fn advance_mut(&mut self, cnt: usize) {
        if !cnt.is_multiple_of(*self.alignment) {
            vortex_panic!(
                "Cannot advance buffer by {} bytes, the buffer would no longer be aligned to {}",
                cnt,
                self.alignment
            );
        }
        unsafe { self.bytes.advance_mut(cnt) };
        // Advancing the write cursor marks `cnt` more bytes as initialized, so the
        // logical length grows accordingly.
        self.length += cnt;
    }

    #[inline]
    fn chunk_mut(&mut self) -> &mut UninitSlice {
        self.bytes.chunk_mut()
    }

    fn put<T: Buf>(&mut self, mut src: T)
    where
        Self: Sized,
    {
        while src.has_remaining() {
            let chunk = src.chunk();
            self.extend_from_slice(chunk);
            src.advance(chunk.len());
        }
    }

    #[inline]
    fn put_slice(&mut self, src: &[u8]) {
        self.extend_from_slice(src);
    }

    #[inline]
    fn put_bytes(&mut self, val: u8, cnt: usize) {
        self.push_n(val, cnt)
    }
}

trait AlignedBytesMut {
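    /// Align an empty [`BytesMut`] by advancing its start pointer to the next `alignment`
    /// boundary, consuming some of its capacity as padding.
    ///
    /// Panics if the buffer is not empty or does not have enough capacity for the padding.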
    fn align_empty(&mut self, alignment: Alignment);
}

impl AlignedBytesMut for BytesMut {
    fn align_empty(&mut self, alignment: Alignment) {
        if !self.is_empty() {
            vortex_panic!("ByteBufferMut must be empty");
        }

        let padding = self.as_ptr().align_offset(*alignment);
        self.capacity()
            .checked_sub(padding)
            .vortex_expect("Not enough capacity to align buffer");

        unsafe { self.set_len(padding) };
        self.advance(padding);
    }
}

#[cfg(test)]
mod test {
    use bytes::{Buf, BufMut};

    use crate::{Alignment, BufferMut, ByteBufferMut, buffer_mut};

    #[test]
    fn capacity() {
        let mut n = 57;
        let mut buf = BufferMut::<i32>::with_capacity_aligned(n, Alignment::new(1024));
        assert!(buf.capacity() >= 57);

        while n > 0 {
            buf.push(0);
            assert!(buf.capacity() >= n);
            n -= 1
        }

        assert_eq!(buf.alignment(), Alignment::new(1024));
    }

    #[test]
    fn from_iter() {
        let buf = BufferMut::from_iter([0, 10, 20, 30]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30]);
    }

    #[test]
    fn extend() {
        let mut buf = BufferMut::empty();
        buf.extend([0i32, 10, 20, 30]);
        buf.extend([40, 50, 60]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30, 40, 50, 60]);
    }

    #[test]
    fn push() {
        let mut buf = BufferMut::empty();
        buf.push(1);
        buf.push(2);
        buf.push(3);
        assert_eq!(buf.as_slice(), &[1, 2, 3]);
    }

    #[test]
    fn push_n() {
        let mut buf = BufferMut::empty();
        buf.push_n(0, 100);
        assert_eq!(buf.as_slice(), &[0; 100]);
    }

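    // Illustrative test sketch: exercises `zeroed`, `truncate`, and `clear` exactly as
    // defined above; the element type and lengths are arbitrary example values.
    #[test]
    fn zeroed_truncate_clear() {
        let mut buf = BufferMut::<u64>::zeroed(8);
        assert_eq!(buf.as_slice(), &[0u64; 8]);

        // Truncating keeps the front of the buffer.
        buf.truncate(3);
        assert_eq!(buf.len(), 3);

        // Truncating to a larger length is a no-op.
        buf.truncate(100);
        assert_eq!(buf.len(), 3);

        buf.clear();
        assert!(buf.is_empty());
    }
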
    #[test]
    fn as_mut() {
        let mut buf = buffer_mut![0, 1, 2];
        buf[1] = 0;
        buf.as_mut()[2] = 0;
        assert_eq!(buf.as_slice(), &[0, 0, 0]);
    }

    #[test]
    fn map_each() {
        let buf = buffer_mut![0i32, 1, 2];
        let buf = buf.map_each(|i| (i + 1) as u32);
        assert_eq!(buf.as_slice(), &[1u32, 2, 3]);
    }

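    // Illustrative test sketch: exercises `full` and `aligned` as defined above. The
    // 128-byte alignment is an arbitrary example value.
    #[test]
    fn full_and_realign() {
        let buf = BufferMut::full(7u16, 4);
        assert_eq!(buf.as_slice(), &[7u16; 4]);

        // `aligned` copies into a new allocation only if the buffer is not already aligned.
        let realigned = buf.aligned(Alignment::new(128));
        assert_eq!(realigned.as_slice(), &[7u16; 4]);
        assert_eq!(realigned.as_ptr().align_offset(128), 0);
    }
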
    #[test]
    fn bytes_buf() {
        let mut buf = ByteBufferMut::copy_from("helloworld".as_bytes());
        assert_eq!(buf.remaining(), 10);
        assert_eq!(buf.chunk(), b"helloworld");

        Buf::advance(&mut buf, 5);
        assert_eq!(buf.remaining(), 5);
        assert_eq!(buf.as_slice(), b"world");
        assert_eq!(buf.chunk(), b"world");
    }

    #[test]
    fn bytes_buf_mut() {
        let mut buf = ByteBufferMut::copy_from("hello".as_bytes());
        assert_eq!(BufMut::remaining_mut(&buf), usize::MAX - 5);

        BufMut::put_slice(&mut buf, b"world");
        assert_eq!(buf.as_slice(), b"helloworld");
    }
}