vortex_array/pipeline/
view.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use vortex_buffer::ByteBuffer;
7use vortex_error::VortexExpect;
8
9use crate::pipeline::N;
10use crate::pipeline::bits::{BitView, BitViewMut};
11use crate::pipeline::types::{Element, VType};
12
13pub struct View<'a> {
14    /// The physical type of the vector, which defines how the elements are stored.
15    pub(super) vtype: VType,
16    /// A pointer to the allocated elements buffer.
17    /// Alignment is at least the size of the element type.
18    /// The capacity of the elements buffer is N * `size_of::<T>()` where T is the element type.
19    pub(super) elements: *const u8,
20    /// The validity mask for the vector, indicating which elements in the buffer are valid.
21    /// This value can be `None` if the expected DType is `NonNullable`.
22    // TODO: support validity
23    #[allow(dead_code)]
24    pub(super) validity: Option<BitView<'a>>,
25    // A selection mask over the elements and validity of the vector.
26    pub(super) len: usize,
27
28    /// Additional buffers of data used by the vector, such as string data.
29    #[allow(dead_code)]
30    pub(super) data: Vec<ByteBuffer>,
31
32    /// Marker defining the lifetime of the contents of the vector.
33    pub(super) _marker: std::marker::PhantomData<&'a ()>,
34}
35
36impl<'a> View<'a> {
37    #[inline(always)]
38    pub fn len(&self) -> usize {
39        self.len
40    }
41
42    pub fn is_empty(&self) -> bool {
43        self.len == 0
44    }
45
46    pub fn as_slice<T>(&self) -> &'a [T]
47    where
48        T: Element,
49    {
50        debug_assert_eq!(self.vtype, T::vtype(), "Invalid type for canonical view");
51        // SAFETY: We assume that the elements are of type T and that the view is valid.
52        unsafe { std::slice::from_raw_parts(self.elements.cast(), self.len) }
53    }
54
55    /// Re-interpret cast the vector into a new type where the element has the same width.
56    #[inline(always)]
57    pub fn reinterpret_as<E: Element>(&mut self) {
58        assert_eq!(
59            self.vtype.byte_width(),
60            size_of::<E>(),
61            "Cannot reinterpret {} as {}",
62            self.vtype,
63            E::vtype()
64        );
65        self.vtype = E::vtype();
66    }
67}
68
69pub struct ViewMut<'a> {
70    /// The physical type of the vector, which defines how the elements are stored.
71    pub(super) vtype: VType,
72    /// A pointer to the allocated elements buffer.
73    /// Alignment is at least the size of the element type.
74    /// The capacity of the elements buffer is N * `size_of::<T>()` where T is the element type.
75    // TODO(ngates): it would be nice to guarantee _wider_ alignment, ideally 128 bytes, so that
76    //  we can use aligned load/store instructions for wide SIMD lanes.
77    pub(super) elements: *mut u8,
78    /// The validity mask for the vector, indicating which elements in the buffer are valid.
79    /// This value can be `None` if the expected DType is `NonNullable`.
80    pub(super) validity: Option<BitViewMut<'a>>,
81
82    /// Additional buffers of data used by the vector, such as string data.
83    // TODO(ngates): ideally these buffers are compressed somehow? E.g. using FSST?
84    #[allow(dead_code)]
85    pub(super) data: Vec<ByteBuffer>,
86
87    /// Marker defining the lifetime of the contents of the vector.
88    pub(super) _marker: std::marker::PhantomData<&'a mut ()>,
89}
90
91impl<'a> ViewMut<'a> {
92    pub fn new<E: Element>(elements: &'a mut [E], validity: Option<BitViewMut<'a>>) -> Self {
93        assert_eq!(elements.len(), N);
94        Self {
95            vtype: E::vtype(),
96            elements: elements.as_mut_ptr().cast(),
97            validity,
98            data: vec![],
99            _marker: Default::default(),
100        }
101    }
102
103    /// Re-interpret cast the vector into a new type where the element has the same width.
104    #[inline(always)]
105    pub fn reinterpret_as<E: Element>(&mut self) {
106        assert_eq!(
107            self.vtype.byte_width(),
108            size_of::<E>(),
109            "Cannot reinterpret {} as {}",
110            self.vtype,
111            E::vtype()
112        );
113        self.vtype = E::vtype();
114    }
115
116    /// Returns an immutable slice of the elements in the vector.
117    #[inline(always)]
118    pub fn as_slice<E: Element>(&self) -> &'a [E] {
119        debug_assert_eq!(self.vtype, E::vtype(), "Invalid type for canonical view");
120        unsafe { std::slice::from_raw_parts(self.elements.cast::<E>(), N) }
121    }
122
123    /// Returns a mutable slice of the elements in the vector, allowing for modification.
124    #[inline(always)]
125    pub fn as_slice_mut<E: Element>(&mut self) -> &'a mut [E] {
126        debug_assert_eq!(self.vtype, E::vtype(), "Invalid type for canonical view");
127        unsafe { std::slice::from_raw_parts_mut(self.elements.cast::<E>(), N) }
128    }
129
130    /// Access the validity mask of the vector.
131    ///
132    /// ## Panics
133    ///
134    /// Panics if the vector does not support validity, i.e. if the DType was non-nullable when
135    /// it was created.
136    pub fn validity(&mut self) -> &mut BitViewMut<'a> {
137        self.validity
138            .as_mut()
139            .vortex_expect("Vector does not support validity")
140    }
141
142    pub fn add_buffer(&mut self, buffer: ByteBuffer) {
143        self.data.push(buffer);
144    }
145
146    /// Flatten the view by bringing the selected elements of the mask to the beginning of
147    /// the elements buffer.
148    ///
149    /// FIXME(ngates): also need to select validity bits.
150    pub fn select_mask<E: Element + Display>(&mut self, mask: &BitView) {
151        assert_eq!(
152            self.vtype,
153            E::vtype(),
154            "ViewMut::flatten_mask: type mismatch"
155        );
156
157        match mask.true_count() {
158            0 => {
159                // If the mask has no true bits, we set the length to 0.
160            }
161            N => {
162                // If the mask has N true bits, we copy all elements.
163            }
164            n if n > 3 * N / 4 => {
165                // High density: use iter_zeros to compact by removing gaps
166                let slice = self.as_slice_mut::<E>();
167                let mut write_idx = 0;
168                let mut read_idx = 0;
169
170                mask.iter_zeros(|zero_idx| {
171                    // Copy elements from read_idx to zero_idx (exclusive) to write_idx
172                    let count = zero_idx - read_idx;
173                    unsafe {
174                        // SAFETY: We assume that the elements are of type E and that the view is valid.
175                        // Using memmove for potentially overlapping regions
176                        std::ptr::copy(
177                            slice.as_ptr().add(read_idx),
178                            slice.as_mut_ptr().add(write_idx),
179                            count,
180                        );
181                        write_idx += count;
182                    }
183                    read_idx = zero_idx + 1;
184                });
185
186                // Copy any remaining elements after the last zero
187                unsafe {
188                    std::ptr::copy(
189                        slice.as_ptr().add(read_idx),
190                        slice.as_mut_ptr().add(write_idx),
191                        N - read_idx,
192                    );
193                }
194            }
195            _ => {
196                let mut offset = 0;
197                let slice = self.as_slice_mut::<E>();
198                mask.iter_ones(|idx| {
199                    unsafe {
200                        // SAFETY: We assume that the elements are of type E and that the view is valid.
201                        let value = *slice.get_unchecked(idx);
202                        // TODO(joe): use ptr increment (not offset).
203                        *slice.get_unchecked_mut(offset) = value;
204
205                        offset += 1;
206                    }
207                });
208            }
209        }
210    }
211}