Skip to main content

vortex_array/arrays/variant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5mod operations;
6mod validity;
7
8use kernel::PARENT_KERNELS;
9use prost::Message;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_proto::dtype as pb;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17use vortex_utils::aliases::hash_set::HashSet;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::ExecutionResult;
22use crate::array::Array;
23use crate::array::ArrayId;
24use crate::array::ArrayParts;
25use crate::array::ArrayView;
26use crate::array::EmptyArrayData;
27use crate::array::VTable;
28use crate::arrays::variant::CORE_STORAGE_SLOT;
29use crate::arrays::variant::NUM_SLOTS;
30use crate::arrays::variant::SHREDDED_SLOT;
31use crate::arrays::variant::SLOT_NAMES;
32use crate::arrays::variant::compute::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::dtype::FieldName;
36use crate::dtype::FieldNames;
37use crate::dtype::Nullability;
38use crate::dtype::StructFields;
39use crate::scalar::Scalar;
40use crate::scalar::ScalarValue;
41use crate::serde::ArrayChildren;
42
/// A [`Variant`]-encoded Vortex array.
///
/// Holds a required core storage child whose dtype and length equal the array's own, plus an
/// optional shredded child of the same length (these invariants are enforced by
/// `Variant::validate`).
pub type VariantArray = Array<Variant>;
45
/// Encoding marker for variant arrays.
///
/// Its [`VTable`] implementation (encoding id `vortex.variant`) defines validation,
/// serialization, and execution for [`VariantArray`].
#[derive(Clone, Debug)]
pub struct Variant;
48
/// Serialized metadata of a [`VariantArray`].
///
/// The core storage dtype is not stored: on deserialization it is taken to be the outer
/// array dtype.
#[derive(Clone, prost::Message)]
struct VariantMetadataProto {
    /// DType of the shredded child; `None` when the array has no shredded child.
    #[prost(message, optional, tag = "1")]
    pub shredded_dtype: Option<pb::DType>,
}
54
55impl VTable for Variant {
56    type TypedArrayData = EmptyArrayData;
57
58    type OperationsVTable = Self;
59
60    type ValidityVTable = Self;
61
62    fn id(&self) -> ArrayId {
63        static ID: CachedId = CachedId::new("vortex.variant");
64        *ID
65    }
66
67    fn validate(
68        &self,
69        _data: &Self::TypedArrayData,
70        dtype: &DType,
71        len: usize,
72        slots: &[Option<ArrayRef>],
73    ) -> VortexResult<()> {
74        vortex_ensure!(
75            slots.len() == NUM_SLOTS,
76            "VariantArray expects {NUM_SLOTS} slots, got {}",
77            slots.len()
78        );
79        vortex_ensure!(
80            slots[CORE_STORAGE_SLOT].is_some(),
81            "VariantArray core_storage slot must be present"
82        );
83        let core_storage = slots[CORE_STORAGE_SLOT]
84            .as_ref()
85            .vortex_expect("validated core_storage slot presence");
86        vortex_ensure!(
87            matches!(dtype, DType::Variant(_)),
88            "Expected Variant DType, got {dtype}"
89        );
90        vortex_ensure!(
91            core_storage.dtype() == dtype,
92            "VariantArray core_storage dtype {} does not match outer dtype {}",
93            core_storage.dtype(),
94            dtype
95        );
96        vortex_ensure!(
97            core_storage.len() == len,
98            "VariantArray core_storage length {} does not match outer length {}",
99            core_storage.len(),
100            len
101        );
102        if let Some(shredded) = slots[SHREDDED_SLOT].as_ref() {
103            vortex_ensure!(
104                shredded.len() == len,
105                "VariantArray shredded length {} does not match outer length {}",
106                shredded.len(),
107                len
108            );
109        }
110        Ok(())
111    }
112
113    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
114        0
115    }
116
117    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
118        vortex_panic!("VariantArray buffer index {idx} out of bounds")
119    }
120
121    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
122        None
123    }
124
125    fn serialize(
126        array: ArrayView<'_, Self>,
127        _session: &VortexSession,
128    ) -> VortexResult<Option<Vec<u8>>> {
129        let shredded_dtype = array.slots()[SHREDDED_SLOT]
130            .as_ref()
131            .map(|shredded| shredded.dtype().try_into())
132            .transpose()?;
133        Ok(Some(
134            VariantMetadataProto { shredded_dtype }.encode_to_vec(),
135        ))
136    }
137
138    fn deserialize(
139        &self,
140        dtype: &DType,
141        len: usize,
142        metadata: &[u8],
143
144        buffers: &[BufferHandle],
145        children: &dyn ArrayChildren,
146        session: &VortexSession,
147    ) -> VortexResult<ArrayParts<Self>> {
148        vortex_ensure!(
149            buffers.is_empty(),
150            "VariantArray expects 0 buffers, got {}",
151            buffers.len()
152        );
153        let proto = VariantMetadataProto::decode(metadata)?;
154        let shredded_dtype = proto
155            .shredded_dtype
156            .as_ref()
157            .map(|dtype| DType::from_proto(dtype, session))
158            .transpose()?;
159        vortex_ensure!(matches!(dtype, DType::Variant(_)), "Expected Variant DType");
160        let expected_children = 1 + usize::from(shredded_dtype.is_some());
161        vortex_ensure!(
162            children.len() == expected_children,
163            "Expected {} children, got {}",
164            expected_children,
165            children.len(),
166        );
167        let core_storage = children.get(0, dtype, len)?;
168        let shredded = shredded_dtype
169            .map(|dtype| children.get(1, &dtype, len))
170            .transpose()?;
171        Ok(
172            ArrayParts::new(self.clone(), dtype.clone(), len, EmptyArrayData)
173                .with_slots(vec![Some(core_storage), shredded].into()),
174        )
175    }
176
177    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
178        match SLOT_NAMES.get(idx) {
179            Some(name) => (*name).to_string(),
180            None => vortex_panic!("VariantArray slot_name index {idx} out of bounds"),
181        }
182    }
183
184    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
185        Ok(ExecutionResult::done(array))
186    }
187
188    fn reduce_parent(
189        array: ArrayView<'_, Self>,
190        parent: &ArrayRef,
191        child_idx: usize,
192    ) -> VortexResult<Option<ArrayRef>> {
193        RULES.evaluate(array, parent, child_idx)
194    }
195
196    fn execute_parent(
197        array: ArrayView<'_, Self>,
198        parent: &ArrayRef,
199        child_idx: usize,
200        ctx: &mut ExecutionCtx,
201    ) -> VortexResult<Option<ArrayRef>> {
202        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
203    }
204}
205
206fn merge_typed_scalar_as_variant(
207    typed_scalar: Scalar,
208    fallback_scalar: Option<Scalar>,
209    dtype: &DType,
210) -> VortexResult<Scalar> {
211    let scalar = if typed_scalar.is_null() {
212        fallback_scalar.unwrap_or_else(|| Scalar::null(dtype.clone()))
213    } else if matches!(
214        typed_scalar.dtype(),
215        DType::List(..) | DType::FixedSizeList(..)
216    ) {
217        Scalar::variant(typed_list_as_variant_payload(typed_scalar)?)
218    } else if typed_scalar.dtype().is_struct() {
219        merge_typed_object_as_variant(typed_scalar, fallback_scalar)?
220    } else if typed_scalar.dtype().is_variant() {
221        typed_scalar
222    } else {
223        Scalar::variant(typed_scalar)
224    };
225
226    if scalar.dtype() == dtype {
227        Ok(scalar)
228    } else {
229        scalar.cast(dtype)
230    }
231}
232
233fn typed_list_as_variant_payload(typed_scalar: Scalar) -> VortexResult<Scalar> {
234    let list = typed_scalar.as_list();
235    let elements = list
236        .elements()
237        .unwrap_or_default()
238        .into_iter()
239        .map(|element| {
240            if element.dtype().is_variant() {
241                element
242            } else {
243                Scalar::variant(element)
244            }
245        })
246        .collect();
247    Ok(Scalar::list(
248        DType::Variant(Nullability::NonNullable),
249        elements,
250        Nullability::NonNullable,
251    ))
252}
253
254fn merge_typed_object_as_variant(
255    typed_scalar: Scalar,
256    fallback_scalar: Option<Scalar>,
257) -> VortexResult<Scalar> {
258    let fallback_inner = fallback_scalar
259        .as_ref()
260        .and_then(|scalar| scalar.as_variant().value())
261        .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null());
262    let Some(fallback_inner) = fallback_inner else {
263        return Ok(Scalar::variant(typed_scalar));
264    };
265
266    merge_struct_payload(&typed_scalar, Some(fallback_inner)).map(Scalar::variant)
267}
268
269fn merge_struct_payload(typed: &Scalar, raw: Option<&Scalar>) -> VortexResult<Scalar> {
270    let typed_struct = typed.as_struct();
271    let raw_struct = raw
272        .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null())
273        .map(Scalar::as_struct);
274    let mut present_typed_fields = HashSet::new();
275    let mut names = Vec::new();
276    let mut values = Vec::new();
277
278    for name in typed_struct.names().iter() {
279        let Some(typed_field) = typed_struct.field(name.as_ref()) else {
280            continue;
281        };
282        if typed_field.is_null() {
283            continue;
284        }
285
286        let raw_field = raw_struct.and_then(|raw_struct| raw_struct.field(name.as_ref()));
287        let raw_payload = raw_field.as_ref().and_then(|scalar| {
288            if scalar.dtype().is_variant() {
289                scalar.as_variant().value()
290            } else {
291                Some(scalar)
292            }
293        });
294        let field = if typed_field.dtype().is_struct()
295            && raw_payload.is_some_and(|raw| raw.dtype().is_struct() && !raw.is_null())
296        {
297            Scalar::variant(merge_struct_payload(&typed_field, raw_payload)?)
298        } else if typed_field.dtype().is_variant() {
299            typed_field.cast(&DType::Variant(Nullability::NonNullable))?
300        } else {
301            Scalar::variant(typed_field)
302        };
303
304        present_typed_fields.insert(name.as_ref().to_string());
305        names.push(FieldName::from(name.as_ref()));
306        values.push(field.into_value());
307    }
308
309    if let Some(raw_struct) = raw_struct {
310        for name in raw_struct.names().iter() {
311            if present_typed_fields.contains(name.as_ref()) {
312                continue;
313            }
314            let Some(raw_field) = raw_struct.field(name.as_ref()) else {
315                continue;
316            };
317            if raw_field.is_null() {
318                continue;
319            }
320            let raw_field = if raw_field.dtype().is_variant() {
321                raw_field.cast(&DType::Variant(Nullability::NonNullable))?
322            } else {
323                Scalar::variant(raw_field)
324            };
325            names.push(FieldName::from(name.as_ref()));
326            values.push(raw_field.into_value());
327        }
328    }
329
330    let fields = StructFields::new(
331        FieldNames::from(names),
332        vec![DType::Variant(Nullability::NonNullable); values.len()],
333    );
334    Scalar::try_new(
335        DType::Struct(fields, Nullability::NonNullable),
336        Some(ScalarValue::Tuple(values)),
337    )
338}
339
#[cfg(test)]
mod tests {
    // NOTE(review): this module is empty. Candidate coverage: field precedence in
    // `merge_struct_payload` (typed over raw, null typed falling back to raw), and a
    // serialize/deserialize round trip both with and without a shredded child.
}