Skip to main content

vortex_array/arrays/primitive/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_error::vortex_ensure;
7use vortex_error::vortex_panic;
8
9use crate::ArrayParts;
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::ExecutionResult;
13use crate::array::Array;
14use crate::array::ArrayView;
15use crate::array::VTable;
16use crate::arrays::primitive::PrimitiveData;
17use crate::buffer::BufferHandle;
18use crate::builders::ArrayBuilder;
19use crate::builders::PrimitiveBuilder;
20use crate::dtype::DType;
21use crate::dtype::PType;
22use crate::match_each_native_ptype;
23use crate::serde::ArrayChildren;
24use crate::validity::Validity;
25mod kernel;
26mod operations;
27mod validity;
28
29use std::hash::Hasher;
30
31use vortex_buffer::Alignment;
32use vortex_session::VortexSession;
33use vortex_session::registry::CachedId;
34
35use crate::EqMode;
36use crate::array::ArrayId;
37use crate::arrays::primitive::array::SLOT_NAMES;
38use crate::arrays::primitive::compute::rules::RULES;
39use crate::hash::ArrayEq;
40use crate::hash::ArrayHash;
41
42/// A [`Primitive`]-encoded Vortex array.
43pub type PrimitiveArray = Array<Primitive>;
44
45pub(crate) fn initialize(session: &VortexSession) {
46    kernel::initialize(session);
47}
48
49impl ArrayHash for PrimitiveData {
50    fn array_hash<H: Hasher>(&self, state: &mut H, accuracy: EqMode) {
51        self.buffer.array_hash(state, accuracy);
52    }
53}
54
55impl ArrayEq for PrimitiveData {
56    fn array_eq(&self, other: &Self, accuracy: EqMode) -> bool {
57        self.buffer.array_eq(&other.buffer, accuracy)
58    }
59}
60
61impl VTable for Primitive {
62    type TypedArrayData = PrimitiveData;
63
64    type OperationsVTable = Self;
65    type ValidityVTable = Self;
66
67    fn id(&self) -> ArrayId {
68        static ID: CachedId = CachedId::new("vortex.primitive");
69        *ID
70    }
71
72    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
73        1
74    }
75
76    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
77        match idx {
78            0 => array.buffer_handle().clone(),
79            _ => vortex_panic!("PrimitiveArray buffer index {idx} out of bounds"),
80        }
81    }
82
83    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
84        match idx {
85            0 => Some("values".to_string()),
86            _ => None,
87        }
88    }
89
90    fn with_buffers(
91        &self,
92        array: ArrayView<'_, Self>,
93        buffers: &[BufferHandle],
94    ) -> VortexResult<ArrayParts<Self>> {
95        vortex_ensure!(
96            buffers.len() == 1,
97            "Expected 1 buffer, got {}",
98            buffers.len()
99        );
100        let mut data = array.data().clone();
101        data.buffer = buffers[0].clone();
102        Ok(
103            ArrayParts::new(self.clone(), array.dtype().clone(), array.len(), data)
104                .with_slots(array.slots().iter().cloned().collect()),
105        )
106    }
107
108    fn serialize(
109        _array: ArrayView<'_, Self>,
110        _session: &VortexSession,
111    ) -> VortexResult<Option<Vec<u8>>> {
112        Ok(Some(vec![]))
113    }
114
115    fn validate(
116        &self,
117        data: &PrimitiveData,
118        dtype: &DType,
119        len: usize,
120        slots: &[Option<ArrayRef>],
121    ) -> VortexResult<()> {
122        let DType::Primitive(_, nullability) = dtype else {
123            vortex_bail!("Expected primitive dtype, got {dtype:?}");
124        };
125        vortex_ensure!(
126            data.len() == len,
127            "PrimitiveArray length {} does not match outer length {}",
128            data.len(),
129            len
130        );
131        let validity = crate::array::child_to_validity(slots[0].as_ref(), *nullability);
132        if let Some(validity_len) = validity.maybe_len() {
133            vortex_ensure!(
134                validity_len == len,
135                "PrimitiveArray validity len {} does not match outer length {}",
136                validity_len,
137                len
138            );
139        }
140
141        Ok(())
142    }
143
144    fn deserialize(
145        &self,
146        dtype: &DType,
147        len: usize,
148        metadata: &[u8],
149
150        buffers: &[BufferHandle],
151        children: &dyn ArrayChildren,
152        _session: &VortexSession,
153    ) -> VortexResult<ArrayParts<Self>> {
154        if !metadata.is_empty() {
155            vortex_bail!(
156                "PrimitiveArray expects empty metadata, got {} bytes",
157                metadata.len()
158            );
159        }
160        if buffers.len() != 1 {
161            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
162        }
163        let buffer = buffers[0].clone();
164
165        let validity = if children.is_empty() {
166            Validity::from(dtype.nullability())
167        } else if children.len() == 1 {
168            let validity = children.get(0, &Validity::DTYPE, len)?;
169            Validity::Array(validity)
170        } else {
171            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
172        };
173
174        let ptype = PType::try_from(dtype)?;
175
176        vortex_ensure!(
177            buffer.is_aligned_to(Alignment::new(ptype.byte_width())),
178            "Misaligned buffer cannot be used to build PrimitiveArray of {ptype}"
179        );
180
181        if buffer.len() != ptype.byte_width() * len {
182            vortex_bail!(
183                "Buffer length {} does not match expected length {} for {}, {}",
184                buffer.len(),
185                ptype.byte_width() * len,
186                ptype.byte_width(),
187                len,
188            );
189        }
190
191        vortex_ensure!(
192            buffer.is_aligned_to(Alignment::new(ptype.byte_width())),
193            "PrimitiveArray::build: Buffer (align={}) must be aligned to {}",
194            buffer.alignment(),
195            ptype.byte_width()
196        );
197
198        // SAFETY: checked ahead of time
199        let slots = PrimitiveData::make_slots(&validity, len);
200        let data = unsafe { PrimitiveData::new_unchecked_from_handle(buffer, ptype, validity) };
201        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
202    }
203
204    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
205        SLOT_NAMES[idx].to_string()
206    }
207
208    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
209        Ok(ExecutionResult::done(array))
210    }
211
212    fn append_to_builder(
213        array: ArrayView<'_, Self>,
214        builder: &mut dyn ArrayBuilder,
215        ctx: &mut ExecutionCtx,
216    ) -> VortexResult<()> {
217        match_each_native_ptype!(array.ptype(), |P| {
218            if let Some(builder) = builder.as_any_mut().downcast_mut::<PrimitiveBuilder<P>>() {
219                return builder.append_primitive_array(&array.into_owned(), ctx);
220            }
221        });
222
223        builder.extend_from_array(array.as_ref());
224        Ok(())
225    }
226
227    fn reduce_parent(
228        array: ArrayView<'_, Self>,
229        parent: &ArrayRef,
230        child_idx: usize,
231    ) -> VortexResult<Option<ArrayRef>> {
232        RULES.evaluate(array, parent, child_idx)
233    }
234}
235
236#[derive(Clone, Debug)]
237pub struct Primitive;
238
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::buffer::BufferHandle;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// A nullable primitive array must survive serialize -> concatenate -> decode unchanged,
    /// including its validity.
    #[test]
    fn test_nullable_primitive_serde_roundtrip() {
        let session = array_session();
        let mut exec_ctx = session.create_execution_ctx();

        let original = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        );
        let (dtype, len) = (original.dtype().clone(), original.len());

        let array_ctx = ArrayContext::empty();
        let chunks = original
            .clone()
            .into_array()
            .serialize(&array_ctx, &session, &SerializeOptions::default())
            .unwrap();

        // Flatten the serialized chunks into one contiguous buffer before parsing it back.
        let mut bytes = ByteBufferMut::empty();
        chunks
            .into_iter()
            .for_each(|chunk| bytes.extend_from_slice(chunk.as_ref()));

        // Built only after serialization, since serializing may register ids in `array_ctx`.
        let read_ctx = ReadContext::new(array_ctx.to_ids());
        let serialized_parts = SerializedArray::try_from(bytes.freeze()).unwrap();
        let decoded = serialized_parts
            .decode(&dtype, len, &read_ctx, &session)
            .unwrap();

        assert_arrays_eq!(decoded, original, &mut exec_ctx);
    }

    /// Swapping in a fresh buffer with identical contents yields an equal array.
    #[test]
    fn test_with_buffers_replaces_primitive_buffer_with_equivalent_contents() -> VortexResult<()> {
        let session = array_session();
        let mut exec_ctx = session.create_execution_ctx();

        let expected = PrimitiveArray::from_iter([1i32, 2, 3, 4]);
        let original = PrimitiveArray::from_iter([1i32, 2, 3, 4]).into_array();
        let same_values = BufferHandle::new_host(buffer![1i32, 2, 3, 4].into_byte_buffer());

        // SAFETY: the replacement buffer contains the same logical values as the original array;
        // only the buffer handle changes.
        let rewritten = unsafe { original.with_buffers([same_values]) }?;

        assert_arrays_eq!(rewritten, expected, &mut exec_ctx);
        Ok(())
    }

    /// A replacement buffer holding fewer values than the array's length must be rejected.
    #[test]
    fn test_with_buffers_rejects_length_change() {
        let original = PrimitiveArray::from_iter([1i32, 2, 3, 4]).into_array();
        let shorter = BufferHandle::new_host(buffer![10i32, 20, 30].into_byte_buffer());

        // SAFETY: this call is expected to fail the checked buffer length invariant before any
        // rewritten array is returned or observed.
        let outcome = unsafe { original.with_buffers([shorter]) };
        assert!(outcome.is_err());
    }
}