1use core::mem::MaybeUninit;
2use std::any::type_name;
3use std::fmt::{Debug, Formatter};
4use std::io::Write;
5use std::ops::{Deref, DerefMut};
6
7use bytes::buf::UninitSlice;
8use bytes::{Buf, BufMut, BytesMut};
9use vortex_error::{VortexExpect, vortex_panic};
10
11use crate::debug::TruncatedDebug;
12use crate::spec_extend::SpecExtend;
13use crate::{Alignment, Buffer, ByteBufferMut};
14
/// A mutable, growable buffer of elements of type `T`, backed by a `BytesMut`
/// whose data pointer is kept at a caller-specified [`Alignment`].
// NOTE(review): elements are stored as raw bytes and destructors are never run
// (e.g. `clear`/`set_len` only adjust lengths) — presumably `T` is restricted to
// plain-old-data; confirm.
#[derive(PartialEq, Eq)]
pub struct BufferMut<T> {
    // Raw backing storage; its start is pre-advanced to an aligned address.
    pub(crate) bytes: BytesMut,
    // Number of `T` elements; mirrors `bytes.len() / size_of::<T>()`.
    pub(crate) length: usize,
    // Alignment guarantee of the data pointer.
    pub(crate) alignment: Alignment,
    // Marks logical ownership of `T` values without storing one.
    pub(crate) _marker: std::marker::PhantomData<T>,
}
23
impl<T> BufferMut<T> {
    /// Creates an empty buffer with room for `capacity` elements of `T`, using
    /// `T`'s natural alignment.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity_aligned(capacity, Alignment::of::<T>())
    }

    /// Creates an empty buffer with room for `capacity` elements of `T`, whose
    /// data pointer is aligned to `alignment`.
    ///
    /// # Panics
    ///
    /// Panics if `alignment` is not itself aligned to `T`'s natural alignment.
    pub fn with_capacity_aligned(capacity: usize, alignment: Alignment) -> Self {
        if !alignment.is_aligned_to(Alignment::of::<T>()) {
            vortex_panic!(
                "Alignment {} must align to the scalar type's alignment {}",
                alignment,
                align_of::<T>()
            );
        }

        // Over-allocate by `alignment` extra bytes so the start of the buffer can
        // be shifted forward to an aligned address without losing any of the
        // requested capacity.
        let mut bytes = BytesMut::with_capacity((capacity * size_of::<T>()) + *alignment);
        bytes.align_empty(alignment);

        Self {
            bytes,
            length: 0,
            alignment,
            _marker: Default::default(),
        }
    }

    /// Creates a zero-filled buffer of `len` elements, using `T`'s natural alignment.
    pub fn zeroed(len: usize) -> Self {
        Self::zeroed_aligned(len, Alignment::of::<T>())
    }

    /// Creates a zero-filled buffer of `len` elements whose data pointer is aligned
    /// to `alignment`.
    pub fn zeroed_aligned(len: usize, alignment: Alignment) -> Self {
        // Over-allocate, then consume the padding needed to reach an aligned address.
        let mut bytes = BytesMut::zeroed((len * size_of::<T>()) + *alignment);
        bytes.advance(bytes.as_ptr().align_offset(*alignment));
        // SAFETY: the allocation was zero-initialized, and after advancing it still
        // holds at least `len * size_of::<T>()` bytes.
        unsafe { bytes.set_len(len * size_of::<T>()) };
        Self {
            bytes,
            length: len,
            alignment,
            _marker: Default::default(),
        }
    }

    /// Creates an empty buffer using `T`'s natural alignment.
    pub fn empty() -> Self {
        Self::empty_aligned(Alignment::of::<T>())
    }

    /// Creates an empty buffer whose data pointer is aligned to `alignment`.
    pub fn empty_aligned(alignment: Alignment) -> Self {
        BufferMut::with_capacity_aligned(0, alignment)
    }

    /// Creates a buffer containing `len` copies of `item`.
    pub fn full(item: T, len: usize) -> Self
    where
        T: Copy,
    {
        let mut buffer = BufferMut::<T>::with_capacity(len);
        buffer.push_n(item, len);
        buffer
    }

    /// Creates a buffer by copying the contents of `other`, using `T`'s natural
    /// alignment.
    pub fn copy_from(other: impl AsRef<[T]>) -> Self {
        Self::copy_from_aligned(other, Alignment::of::<T>())
    }

    /// Creates a buffer by copying the contents of `other`, with the given alignment.
    ///
    /// # Panics
    ///
    /// Panics if `alignment` is not aligned to `T`'s natural alignment.
    pub fn copy_from_aligned(other: impl AsRef<[T]>, alignment: Alignment) -> Self {
        if !alignment.is_aligned_to(Alignment::of::<T>()) {
            vortex_panic!("Given alignment is not aligned to type T")
        }
        let other = other.as_ref();
        let mut buffer = Self::with_capacity_aligned(other.len(), alignment);
        buffer.extend_from_slice(other);
        debug_assert_eq!(buffer.alignment(), alignment);
        buffer
    }

    /// Returns the alignment guarantee of the buffer's data pointer.
    #[inline(always)]
    pub fn alignment(&self) -> Alignment {
        self.alignment
    }

    /// Returns the number of elements in the buffer.
    // NOTE(review): for a zero-sized `T` the debug assertion divides by zero;
    // presumably ZSTs are unsupported by this type — confirm.
    #[inline(always)]
    pub fn len(&self) -> usize {
        debug_assert_eq!(self.length, self.bytes.len() / size_of::<T>());
        self.length
    }

    /// Returns true if the buffer holds no elements.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }

    /// Returns how many elements the buffer can hold without reallocating.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.bytes.capacity() / size_of::<T>()
    }

    /// Returns the buffer's contents as a typed slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        let raw_slice = self.bytes.as_ref();
        // SAFETY: the backing bytes are aligned for `T` (enforced at construction)
        // and contain exactly `self.length` initialized elements.
        unsafe { std::slice::from_raw_parts(raw_slice.as_ptr().cast(), self.length) }
    }

    /// Returns the buffer's contents as a mutable typed slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let raw_slice = self.bytes.as_mut();
        // SAFETY: as for `as_slice`, and `&mut self` guarantees unique access.
        unsafe { std::slice::from_raw_parts_mut(raw_slice.as_mut_ptr().cast(), self.length) }
    }

    /// Removes all elements without releasing the allocation.
    #[inline]
    pub fn clear(&mut self) {
        // SAFETY: shrinking the length to zero never exposes uninitialized data.
        unsafe { self.bytes.set_len(0) }
        self.length = 0;
    }

    /// Shortens the buffer to `len` elements; does nothing if `len` exceeds the
    /// current length.
    #[inline]
    pub fn truncate(&mut self, len: usize) {
        if len <= self.len() {
            // SAFETY: `len` is no larger than the current length, so all elements
            // up to `len` remain initialized.
            unsafe { self.set_len(len) };
        }
    }

    /// Reserves capacity for at least `additional` more elements.
    #[inline]
    pub fn reserve(&mut self, additional: usize) {
        let additional_bytes = additional * size_of::<T>();
        // Fast path: existing spare capacity already suffices.
        if additional_bytes <= self.bytes.capacity() - self.bytes.len() {
            return;
        }

        self.reserve_allocate(additional);
    }

    /// Slow path of `reserve`: allocates a fresh, larger aligned allocation and
    /// copies the current contents into it.
    fn reserve_allocate(&mut self, additional: usize) {
        let new_capacity: usize = ((self.length + additional) * size_of::<T>()) + *self.alignment;
        // Grow at least geometrically so repeated pushes stay amortized O(1).
        let new_capacity = new_capacity.max(self.bytes.capacity() * 2);

        let mut bytes = BytesMut::with_capacity(new_capacity);
        bytes.align_empty(self.alignment);
        bytes.extend_from_slice(&self.bytes);
        self.bytes = bytes;
    }

    /// Returns the spare (allocated but uninitialized) portion of the buffer as a
    /// slice of `MaybeUninit<T>`.
    #[inline]
    pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] {
        let dst = self.bytes.spare_capacity_mut().as_mut_ptr();
        // SAFETY: the spare region holds at least `capacity() - length` elements'
        // worth of bytes, correctly aligned for `T`; `MaybeUninit` permits them to
        // be uninitialized.
        unsafe {
            std::slice::from_raw_parts_mut(
                dst as *mut MaybeUninit<T>,
                self.capacity() - self.length,
            )
        }
    }

    /// Forces the length of the buffer to `len` elements.
    ///
    /// # Safety
    ///
    /// The caller must ensure `len <= capacity()` and that the first `len`
    /// elements are initialized.
    #[inline]
    pub unsafe fn set_len(&mut self, len: usize) {
        unsafe { self.bytes.set_len(len * size_of::<T>()) };
        self.length = len;
    }

    /// Appends a single element, growing the buffer if needed.
    #[inline]
    pub fn push(&mut self, value: T) {
        self.reserve(1);
        // SAFETY: `reserve(1)` guarantees room for at least one more element.
        unsafe { self.push_unchecked(value) }
    }

    /// Appends a single element without checking capacity.
    ///
    /// # Safety
    ///
    /// The caller must ensure there is spare capacity for at least one element.
    #[inline]
    pub unsafe fn push_unchecked(&mut self, item: T) {
        unsafe {
            // Write into the first spare slot, then publish it by bumping the
            // byte length.
            let dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
            dst.write(item);
            self.bytes.set_len(self.bytes.len() + size_of::<T>())
        }
        self.length += 1;
    }

    /// Appends `n` copies of `item`, growing the buffer if needed.
    #[inline]
    pub fn push_n(&mut self, item: T, n: usize)
    where
        T: Copy,
    {
        self.reserve(n);
        // SAFETY: `reserve(n)` guarantees room for at least `n` more elements.
        unsafe { self.push_n_unchecked(item, n) }
    }

    /// Appends `n` copies of `item` without checking capacity.
    ///
    /// # Safety
    ///
    /// The caller must ensure there is spare capacity for at least `n` elements.
    #[inline]
    pub unsafe fn push_n_unchecked(&mut self, item: T, n: usize)
    where
        T: Copy,
    {
        let mut dst: *mut T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
        unsafe {
            // Write all `n` copies first, then publish them with one length bump.
            let end = dst.add(n);
            while dst < end {
                dst.write(item);
                dst = dst.add(1);
            }
            self.bytes.set_len(self.bytes.len() + (n * size_of::<T>()));
        }
        self.length += n;
    }

    /// Appends all elements of `slice`, growing the buffer if needed.
    #[inline]
    pub fn extend_from_slice(&mut self, slice: &[T]) {
        self.reserve(slice.len());
        // Reinterpret the typed slice as raw bytes; `size_of_val` yields its
        // length in bytes.
        let raw_slice =
            unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
        self.bytes.extend_from_slice(raw_slice);
        self.length += slice.len();
    }

    /// Converts this mutable buffer into an immutable [`Buffer`], preserving the
    /// length and alignment without copying data.
    pub fn freeze(self) -> Buffer<T> {
        Buffer {
            bytes: self.bytes.freeze(),
            length: self.length,
            alignment: self.alignment,
            _marker: Default::default(),
        }
    }

    /// Transforms each element in place with `f`, reusing the allocation, and
    /// returns the buffer reinterpreted as `BufferMut<R>`.
    ///
    /// # Panics
    ///
    /// Panics if `T` and `R` have different sizes.
    // NOTE(review): only the sizes are checked here; this assumes `R`'s alignment
    // requirement does not exceed the buffer's current alignment — confirm.
    pub fn map_each<R, F>(self, mut f: F) -> BufferMut<R>
    where
        T: Copy,
        F: FnMut(T) -> R,
    {
        assert_eq!(
            size_of::<T>(),
            size_of::<R>(),
            "Size of T and R do not match"
        );
        // SAFETY: `BufferMut<T>` and `BufferMut<R>` have identical layouts (only
        // `PhantomData` differs), and element sizes were just asserted equal.
        let mut buf: BufferMut<R> = unsafe { std::mem::transmute(self) };
        buf.iter_mut()
            .for_each(|item| *item = f(unsafe { std::mem::transmute_copy(item) }));
        buf
    }

    /// Returns a buffer whose data satisfies `alignment`, copying only when the
    /// current data pointer is not already suitably aligned.
    // NOTE(review): when the pointer already happens to be aligned, the stored
    // `alignment` field is left unchanged rather than upgraded — confirm intended.
    pub fn aligned(self, alignment: Alignment) -> Self {
        if self.as_ptr().align_offset(*alignment) == 0 {
            self
        } else {
            Self::copy_from_aligned(self, alignment)
        }
    }
}
365
366impl<T> Clone for BufferMut<T> {
367 fn clone(&self) -> Self {
368 let mut buffer = BufferMut::<T>::with_capacity_aligned(self.capacity(), self.alignment);
371 buffer.extend_from_slice(self.as_slice());
372 buffer
373 }
374}
375
376impl<T: Debug> Debug for BufferMut<T> {
377 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
378 f.debug_struct(&format!("BufferMut<{}>", type_name::<T>()))
379 .field("length", &self.length)
380 .field("alignment", &self.alignment)
381 .field("as_slice", &TruncatedDebug(self.as_slice()))
382 .finish()
383 }
384}
385
impl<T> Default for BufferMut<T> {
    /// Returns an empty buffer with `T`'s natural alignment.
    fn default() -> Self {
        Self::empty()
    }
}
391
impl<T> Deref for BufferMut<T> {
    type Target = [T];

    /// Lets the buffer be used anywhere a `&[T]` slice is expected.
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
399
impl<T> DerefMut for BufferMut<T> {
    /// Lets the buffer be used anywhere a `&mut [T]` slice is expected.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
405
impl<T> AsRef<[T]> for BufferMut<T> {
    /// Borrows the contents as an immutable slice.
    fn as_ref(&self) -> &[T] {
        self.as_slice()
    }
}
411
impl<T> AsMut<[T]> for BufferMut<T> {
    /// Borrows the contents as a mutable slice.
    fn as_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}
417
impl<T> Extend<T> for BufferMut<T> {
    /// Appends every item yielded by `iter`.
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        // Dispatch through `SpecExtend` — presumably specialized per iterator
        // type for faster bulk appends; see `crate::spec_extend`.
        <Self as SpecExtend<T, I::IntoIter>>::spec_extend(self, iter.into_iter())
    }
}
424
impl<'a, T> Extend<&'a T> for BufferMut<T>
where
    T: Copy + 'a,
{
    /// Appends a copy of every item yielded by `iter`.
    #[inline]
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        self.spec_extend(iter.into_iter())
    }
}
434
435impl<T> FromIterator<T> for BufferMut<T> {
436 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
437 let mut buffer = Self::with_capacity(0);
439 buffer.extend(iter);
440 debug_assert_eq!(buffer.alignment(), Alignment::of::<T>());
441 buffer
442 }
443}
444
impl Buf for ByteBufferMut {
    /// Number of readable bytes remaining.
    fn remaining(&self) -> usize {
        self.len()
    }

    /// The buffer is contiguous, so the single chunk is the whole remaining slice.
    fn chunk(&self) -> &[u8] {
        self.as_slice()
    }

    /// Advances the read cursor by `cnt` bytes.
    ///
    /// Panics if `cnt` is not a multiple of the buffer's alignment, since the data
    /// pointer would otherwise lose its alignment guarantee.
    fn advance(&mut self, cnt: usize) {
        if !cnt.is_multiple_of(*self.alignment) {
            vortex_panic!(
                "Cannot advance buffer by {} items, resulting alignment is not {}",
                cnt,
                self.alignment
            );
        }
        self.bytes.advance(cnt);
        // Consuming bytes from the front shrinks the logical length.
        self.length -= cnt;
    }
}
466
467unsafe impl BufMut for ByteBufferMut {
471 #[inline]
472 fn remaining_mut(&self) -> usize {
473 usize::MAX - self.len()
474 }
475
476 #[inline]
477 unsafe fn advance_mut(&mut self, cnt: usize) {
478 if !cnt.is_multiple_of(*self.alignment) {
479 vortex_panic!(
480 "Cannot advance buffer by {} items, resulting alignment is not {}",
481 cnt,
482 self.alignment
483 );
484 }
485 unsafe { self.bytes.advance_mut(cnt) };
486 self.length -= cnt;
487 }
488
489 #[inline]
490 fn chunk_mut(&mut self) -> &mut UninitSlice {
491 self.bytes.chunk_mut()
492 }
493
494 fn put<T: Buf>(&mut self, mut src: T)
495 where
496 Self: Sized,
497 {
498 while src.has_remaining() {
499 let chunk = src.chunk();
500 self.extend_from_slice(chunk);
501 src.advance(chunk.len());
502 }
503 }
504
505 #[inline]
506 fn put_slice(&mut self, src: &[u8]) {
507 self.extend_from_slice(src);
508 }
509
510 #[inline]
511 fn put_bytes(&mut self, val: u8, cnt: usize) {
512 self.push_n(val, cnt)
513 }
514}
515
/// Extension trait for aligning the data pointer of an empty `BytesMut`.
trait AlignedBytesMut {
    /// Advances the start of an empty `BytesMut` so its data pointer is aligned
    /// to `alignment`, consuming spare capacity as padding.
    fn align_empty(&mut self, alignment: Alignment);
}
525
impl AlignedBytesMut for BytesMut {
    fn align_empty(&mut self, alignment: Alignment) {
        // Aligning by advancing the start is only valid before any data exists.
        if !self.is_empty() {
            vortex_panic!("ByteBufferMut must be empty");
        }

        // Bytes needed to move the data pointer up to the next aligned address.
        let padding = self.as_ptr().align_offset(*alignment);
        // Verify the allocation is large enough to absorb the padding.
        self.capacity()
            .checked_sub(padding)
            .vortex_expect("Not enough capacity to align buffer");

        // SAFETY: `padding <= capacity` (checked above); the padding bytes are
        // never read because they are immediately consumed by `advance` below,
        // leaving an empty buffer that starts at the aligned address.
        unsafe { self.set_len(padding) };
        self.advance(padding);
    }
}
543
544impl Write for ByteBufferMut {
545 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
546 self.extend_from_slice(buf);
547 Ok(buf.len())
548 }
549
550 fn flush(&mut self) -> std::io::Result<()> {
551 Ok(())
552 }
553}
554
#[cfg(test)]
mod test {
    use bytes::{Buf, BufMut};

    use crate::{Alignment, BufferMut, ByteBufferMut, buffer_mut};

    // Capacity requested with a custom alignment is honored and survives pushes.
    #[test]
    fn capacity() {
        let mut n = 57;
        let mut buf = BufferMut::<i32>::with_capacity_aligned(n, Alignment::new(1024));
        assert!(buf.capacity() >= 57);

        while n > 0 {
            buf.push(0);
            assert!(buf.capacity() >= n);
            n -= 1
        }

        assert_eq!(buf.alignment(), Alignment::new(1024));
    }

    // Collecting from an iterator preserves element order.
    #[test]
    fn from_iter() {
        let buf = BufferMut::from_iter([0, 10, 20, 30]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30]);
    }

    // Repeated `extend` calls append, in order, across reallocations.
    #[test]
    fn extend() {
        let mut buf = BufferMut::empty();
        buf.extend([0i32, 10, 20, 30]);
        buf.extend([40, 50, 60]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30, 40, 50, 60]);
    }

    // Single-element pushes append in order.
    #[test]
    fn push() {
        let mut buf = BufferMut::empty();
        buf.push(1);
        buf.push(2);
        buf.push(3);
        assert_eq!(buf.as_slice(), &[1, 2, 3]);
    }

    // `push_n` fills with the repeated value.
    #[test]
    fn push_n() {
        let mut buf = BufferMut::empty();
        buf.push_n(0, 100);
        assert_eq!(buf.as_slice(), &[0; 100]);
    }

    // Both `Deref`-based indexing and `as_mut` allow in-place mutation.
    #[test]
    fn as_mut() {
        let mut buf = buffer_mut![0, 1, 2];
        buf[1] = 0;
        buf.as_mut()[2] = 0;
        assert_eq!(buf.as_slice(), &[0, 0, 0]);
    }

    // `map_each` transforms elements in place, including to a same-size type.
    #[test]
    fn map_each() {
        let buf = buffer_mut![0i32, 1, 2];
        let buf = buf.map_each(|i| (i + 1) as u32);
        assert_eq!(buf.as_slice(), &[1u32, 2, 3]);
    }

    // `Buf` impl: remaining/chunk reflect contents; advance consumes from the front.
    #[test]
    fn bytes_buf() {
        let mut buf = ByteBufferMut::copy_from("helloworld".as_bytes());
        assert_eq!(buf.remaining(), 10);
        assert_eq!(buf.chunk(), b"helloworld");

        Buf::advance(&mut buf, 5);
        assert_eq!(buf.remaining(), 5);
        assert_eq!(buf.as_slice(), b"world");
        assert_eq!(buf.chunk(), b"world");
    }

    // `BufMut` impl: `put_slice` appends and the length bookkeeping stays consistent.
    // NOTE(review): this only exercises the overridden `put_slice` path; the
    // `chunk_mut`/`advance_mut` protocol is not covered by any test here.
    #[test]
    fn bytes_buf_mut() {
        let mut buf = ByteBufferMut::copy_from("hello".as_bytes());
        assert_eq!(BufMut::remaining_mut(&buf), usize::MAX - 5);

        BufMut::put_slice(&mut buf, b"world");
        assert_eq!(buf.as_slice(), b"helloworld");
    }
}