Skip to main content

vortex_fastlanes/for/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_array::Array;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayId;
12use vortex_array::ArrayParts;
13use vortex_array::ArrayRef;
14use vortex_array::ArrayView;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionResult;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::arrays::PrimitiveArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::dtype::DType;
22use vortex_array::scalar::Scalar;
23use vortex_array::scalar::ScalarValue;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::vtable::VTable;
26use vortex_array::vtable::ValidityVTableFromChild;
27use vortex_error::VortexExpect;
28use vortex_error::VortexResult;
29use vortex_error::vortex_bail;
30use vortex_error::vortex_ensure;
31use vortex_error::vortex_panic;
32use vortex_session::VortexSession;
33use vortex_session::registry::CachedId;
34
35use crate::FoRData;
36use crate::r#for::array::FoRArrayExt;
37use crate::r#for::array::SLOT_NAMES;
38use crate::r#for::array::for_decompress::decompress;
39use crate::r#for::vtable::kernels::PARENT_KERNELS;
40use crate::r#for::vtable::rules::PARENT_RULES;
41
42mod kernels;
43mod operations;
44mod rules;
45mod slice;
46mod validity;
47
48/// A [`FoR`]-encoded Vortex array.
49pub type FoRArray = Array<FoR>;
50
51impl ArrayHash for FoRData {
52    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
53        self.reference.hash(state);
54    }
55}
56
57impl ArrayEq for FoRData {
58    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
59        self.reference == other.reference
60    }
61}
62
63impl VTable for FoR {
64    type ArrayData = FoRData;
65
66    type OperationsVTable = Self;
67    type ValidityVTable = ValidityVTableFromChild;
68
69    fn id(&self) -> ArrayId {
70        static ID: CachedId = CachedId::new("fastlanes.for");
71        *ID
72    }
73
74    fn validate(
75        &self,
76        data: &Self::ArrayData,
77        dtype: &DType,
78        len: usize,
79        slots: &[Option<ArrayRef>],
80    ) -> VortexResult<()> {
81        let encoded = slots[0].as_ref().vortex_expect("FoRArray encoded slot");
82        validate_parts(encoded.dtype(), encoded.len(), &data.reference, dtype, len)
83    }
84
85    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
86        0
87    }
88
89    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
90        vortex_panic!("FoRArray buffer index {idx} out of bounds")
91    }
92
93    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
94        None
95    }
96
97    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
98        SLOT_NAMES[idx].to_string()
99    }
100
101    fn serialize(
102        array: ArrayView<'_, Self>,
103        _session: &VortexSession,
104    ) -> VortexResult<Option<Vec<u8>>> {
105        // Note that we **only** serialize the optional scalar value (not including the dtype).
106        Ok(Some(ScalarValue::to_proto_bytes(
107            array.reference_scalar().value(),
108        )))
109    }
110
111    fn deserialize(
112        &self,
113        dtype: &DType,
114        len: usize,
115        metadata: &[u8],
116        buffers: &[BufferHandle],
117        children: &dyn ArrayChildren,
118        session: &VortexSession,
119    ) -> VortexResult<ArrayParts<Self>> {
120        vortex_ensure!(
121            buffers.is_empty(),
122            "FoRArray expects 0 buffers, got {}",
123            buffers.len()
124        );
125        if children.len() != 1 {
126            vortex_bail!(
127                "Expected 1 child for FoR encoding, found {}",
128                children.len()
129            )
130        }
131
132        let scalar_value = ScalarValue::from_proto_bytes(metadata, dtype, session)?;
133        let reference = Scalar::try_new(dtype.clone(), scalar_value)?;
134        let encoded = children.get(0, dtype, len)?;
135        let slots = vec![Some(encoded)];
136
137        let data = FoRData::try_new(reference)?;
138        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
139    }
140
141    fn reduce_parent(
142        array: ArrayView<'_, Self>,
143        parent: &ArrayRef,
144        child_idx: usize,
145    ) -> VortexResult<Option<ArrayRef>> {
146        PARENT_RULES.evaluate(array, parent, child_idx)
147    }
148
149    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
150        Ok(ExecutionResult::done(decompress(&array, ctx)?.into_array()))
151    }
152
153    fn execute_parent(
154        array: ArrayView<'_, Self>,
155        parent: &ArrayRef,
156        child_idx: usize,
157        ctx: &mut ExecutionCtx,
158    ) -> VortexResult<Option<ArrayRef>> {
159        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
160    }
161}
162
/// Stateless marker type for the `fastlanes.for` encoding.
///
/// The encoding's behavior is provided by this type's [`VTable`] implementation.
#[derive(Debug, Clone)]
pub struct FoR;
165
166impl FoR {
167    /// Construct a new FoR array from an encoded array and a reference scalar.
168    pub fn try_new(encoded: ArrayRef, reference: Scalar) -> VortexResult<FoRArray> {
169        vortex_ensure!(!reference.is_null(), "Reference value cannot be null");
170        let dtype = reference
171            .dtype()
172            .with_nullability(encoded.dtype().nullability());
173        let reference = reference.cast(&dtype)?;
174        let len = encoded.len();
175        let data = FoRData::try_new(reference)?;
176        let slots = vec![Some(encoded)];
177        Array::try_from_parts(ArrayParts::new(FoR, dtype, len, data).with_slots(slots))
178    }
179
180    /// Encode a primitive array using Frame of Reference encoding.
181    pub fn encode(array: PrimitiveArray) -> VortexResult<FoRArray> {
182        FoRData::encode(array)
183    }
184}
185
186fn validate_parts(
187    encoded_dtype: &DType,
188    encoded_len: usize,
189    reference: &Scalar,
190    dtype: &DType,
191    len: usize,
192) -> VortexResult<()> {
193    vortex_ensure!(dtype.is_int(), "FoR requires an integer dtype, got {dtype}");
194    vortex_ensure!(
195        reference.dtype() == dtype,
196        "FoR reference dtype mismatch: expected {dtype}, got {}",
197        reference.dtype()
198    );
199    vortex_ensure!(
200        encoded_dtype == dtype,
201        "FoR encoded dtype mismatch: expected {dtype}, got {}",
202        encoded_dtype
203    );
204    vortex_ensure!(
205        encoded_len == len,
206        "FoR encoded length mismatch: expected {len}, got {}",
207        encoded_len
208    );
209    Ok(())
210}
211
#[cfg(test)]
mod tests {
    use vortex_array::scalar::ScalarValue;
    use vortex_array::test_harness::check_metadata;

    /// Serializes an `i64::MAX` reference value and hands the bytes to
    /// `check_metadata` under the name `"for.metadata"`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_for_metadata() {
        let reference_value = ScalarValue::from(i64::MAX);
        let encoded: Vec<u8> = ScalarValue::to_proto_bytes(Some(&reference_value));
        check_metadata("for.metadata", &encoded);
    }
}