// Source path: vortex_array/arrays/variant/vtable/mod.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5mod operations;
6mod validity;
7
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_proto::dtype as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16use vortex_utils::aliases::hash_set::HashSet;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::ExecutionResult;
21use crate::array::Array;
22use crate::array::ArrayId;
23use crate::array::ArrayParts;
24use crate::array::ArrayView;
25use crate::array::EmptyArrayData;
26use crate::array::VTable;
27use crate::array::with_empty_buffers;
28use crate::arrays::variant::CORE_STORAGE_SLOT;
29use crate::arrays::variant::NUM_SLOTS;
30use crate::arrays::variant::SHREDDED_SLOT;
31use crate::arrays::variant::SLOT_NAMES;
32use crate::arrays::variant::compute::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::dtype::FieldName;
36use crate::dtype::FieldNames;
37use crate::dtype::Nullability;
38use crate::dtype::StructFields;
39use crate::scalar::Scalar;
40use crate::scalar::ScalarValue;
41use crate::serde::ArrayChildren;
42
43/// A [`Variant`]-encoded Vortex array.
44pub type VariantArray = Array<Variant>;
45
46pub(crate) fn initialize(session: &VortexSession) {
47    kernel::initialize(session);
48}
49
/// Zero-sized encoding marker for variant arrays.
///
/// It carries no state of its own; it exists so that [`VTable`] can be
/// implemented on it.
#[derive(Debug, Clone)]
pub struct Variant;
52
53#[derive(Clone, prost::Message)]
54struct VariantMetadataProto {
55    #[prost(message, optional, tag = "1")]
56    pub shredded_dtype: Option<pb::DType>,
57}
58
59impl VTable for Variant {
60    type TypedArrayData = EmptyArrayData;
61
62    type OperationsVTable = Self;
63
64    type ValidityVTable = Self;
65
66    fn id(&self) -> ArrayId {
67        static ID: CachedId = CachedId::new("vortex.variant");
68        *ID
69    }
70
71    fn validate(
72        &self,
73        _data: &Self::TypedArrayData,
74        dtype: &DType,
75        len: usize,
76        slots: &[Option<ArrayRef>],
77    ) -> VortexResult<()> {
78        vortex_ensure!(
79            slots.len() == NUM_SLOTS,
80            "VariantArray expects {NUM_SLOTS} slots, got {}",
81            slots.len()
82        );
83        vortex_ensure!(
84            slots[CORE_STORAGE_SLOT].is_some(),
85            "VariantArray core_storage slot must be present"
86        );
87        let core_storage = slots[CORE_STORAGE_SLOT]
88            .as_ref()
89            .vortex_expect("validated core_storage slot presence");
90        vortex_ensure!(
91            matches!(dtype, DType::Variant(_)),
92            "Expected Variant DType, got {dtype}"
93        );
94        vortex_ensure!(
95            core_storage.dtype() == dtype,
96            "VariantArray core_storage dtype {} does not match outer dtype {}",
97            core_storage.dtype(),
98            dtype
99        );
100        vortex_ensure!(
101            core_storage.len() == len,
102            "VariantArray core_storage length {} does not match outer length {}",
103            core_storage.len(),
104            len
105        );
106        if let Some(shredded) = slots[SHREDDED_SLOT].as_ref() {
107            vortex_ensure!(
108                shredded.len() == len,
109                "VariantArray shredded length {} does not match outer length {}",
110                shredded.len(),
111                len
112            );
113        }
114        Ok(())
115    }
116
117    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
118        0
119    }
120
121    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
122        vortex_panic!("VariantArray buffer index {idx} out of bounds")
123    }
124
125    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
126        None
127    }
128
129    fn with_buffers(
130        &self,
131        array: ArrayView<'_, Self>,
132        buffers: &[BufferHandle],
133    ) -> VortexResult<ArrayParts<Self>> {
134        with_empty_buffers(self, array, buffers)
135    }
136
137    fn serialize(
138        array: ArrayView<'_, Self>,
139        _session: &VortexSession,
140    ) -> VortexResult<Option<Vec<u8>>> {
141        let shredded_dtype = array.slots()[SHREDDED_SLOT]
142            .as_ref()
143            .map(|shredded| shredded.dtype().try_into())
144            .transpose()?;
145        Ok(Some(
146            VariantMetadataProto { shredded_dtype }.encode_to_vec(),
147        ))
148    }
149
150    fn deserialize(
151        &self,
152        dtype: &DType,
153        len: usize,
154        metadata: &[u8],
155        buffers: &[BufferHandle],
156        children: &dyn ArrayChildren,
157        session: &VortexSession,
158    ) -> VortexResult<ArrayParts<Self>> {
159        vortex_ensure!(
160            buffers.is_empty(),
161            "VariantArray expects 0 buffers, got {}",
162            buffers.len()
163        );
164        let proto = VariantMetadataProto::decode(metadata)?;
165        let shredded_dtype = proto
166            .shredded_dtype
167            .as_ref()
168            .map(|dtype| DType::from_proto(dtype, session))
169            .transpose()?;
170        vortex_ensure!(matches!(dtype, DType::Variant(_)), "Expected Variant DType");
171        let expected_children = 1 + usize::from(shredded_dtype.is_some());
172        vortex_ensure!(
173            children.len() == expected_children,
174            "Expected {} children, got {}",
175            expected_children,
176            children.len(),
177        );
178        let core_storage = children.get(0, dtype, len)?;
179        let shredded = shredded_dtype
180            .map(|dtype| children.get(1, &dtype, len))
181            .transpose()?;
182        Ok(
183            ArrayParts::new(self.clone(), dtype.clone(), len, EmptyArrayData)
184                .with_slots(vec![Some(core_storage), shredded].into()),
185        )
186    }
187
188    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
189        match SLOT_NAMES.get(idx) {
190            Some(name) => (*name).to_string(),
191            None => vortex_panic!("VariantArray slot_name index {idx} out of bounds"),
192        }
193    }
194
195    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
196        Ok(ExecutionResult::done(array))
197    }
198
199    fn reduce_parent(
200        array: ArrayView<'_, Self>,
201        parent: &ArrayRef,
202        child_idx: usize,
203    ) -> VortexResult<Option<ArrayRef>> {
204        RULES.evaluate(array, parent, child_idx)
205    }
206}
207
208fn merge_typed_scalar_as_variant(
209    typed_scalar: Scalar,
210    fallback_scalar: Option<Scalar>,
211    dtype: &DType,
212) -> VortexResult<Scalar> {
213    let scalar = if typed_scalar.is_null() {
214        fallback_scalar.unwrap_or_else(|| Scalar::null(dtype.clone()))
215    } else if matches!(
216        typed_scalar.dtype(),
217        DType::List(..) | DType::FixedSizeList(..)
218    ) {
219        Scalar::variant(typed_list_as_variant_payload(typed_scalar)?)
220    } else if typed_scalar.dtype().is_struct() {
221        merge_typed_object_as_variant(typed_scalar, fallback_scalar)?
222    } else if typed_scalar.dtype().is_variant() {
223        typed_scalar
224    } else {
225        Scalar::variant(typed_scalar)
226    };
227
228    if scalar.dtype() == dtype {
229        Ok(scalar)
230    } else {
231        scalar.cast(dtype)
232    }
233}
234
235fn typed_list_as_variant_payload(typed_scalar: Scalar) -> VortexResult<Scalar> {
236    let list = typed_scalar.as_list();
237    let elements = list
238        .elements()
239        .unwrap_or_default()
240        .into_iter()
241        .map(|element| {
242            if element.dtype().is_variant() {
243                element
244            } else {
245                Scalar::variant(element)
246            }
247        })
248        .collect();
249    Ok(Scalar::list(
250        DType::Variant(Nullability::NonNullable),
251        elements,
252        Nullability::NonNullable,
253    ))
254}
255
256fn merge_typed_object_as_variant(
257    typed_scalar: Scalar,
258    fallback_scalar: Option<Scalar>,
259) -> VortexResult<Scalar> {
260    let fallback_inner = fallback_scalar
261        .as_ref()
262        .and_then(|scalar| scalar.as_variant().value())
263        .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null());
264    let Some(fallback_inner) = fallback_inner else {
265        return Ok(Scalar::variant(typed_scalar));
266    };
267
268    merge_struct_payload(&typed_scalar, Some(fallback_inner)).map(Scalar::variant)
269}
270
271fn merge_struct_payload(typed: &Scalar, raw: Option<&Scalar>) -> VortexResult<Scalar> {
272    let typed_struct = typed.as_struct();
273    let raw_struct = raw
274        .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null())
275        .map(Scalar::as_struct);
276    let mut present_typed_fields = HashSet::new();
277    let mut names = Vec::new();
278    let mut values = Vec::new();
279
280    for name in typed_struct.names().iter() {
281        let Some(typed_field) = typed_struct.field(name.as_ref()) else {
282            continue;
283        };
284        if typed_field.is_null() {
285            continue;
286        }
287
288        let raw_field = raw_struct.and_then(|raw_struct| raw_struct.field(name.as_ref()));
289        let raw_payload = raw_field.as_ref().and_then(|scalar| {
290            if scalar.dtype().is_variant() {
291                scalar.as_variant().value()
292            } else {
293                Some(scalar)
294            }
295        });
296        let field = if typed_field.dtype().is_struct()
297            && raw_payload.is_some_and(|raw| raw.dtype().is_struct() && !raw.is_null())
298        {
299            Scalar::variant(merge_struct_payload(&typed_field, raw_payload)?)
300        } else if typed_field.dtype().is_variant() {
301            typed_field.cast(&DType::Variant(Nullability::NonNullable))?
302        } else {
303            Scalar::variant(typed_field)
304        };
305
306        present_typed_fields.insert(name.as_ref().to_string());
307        names.push(FieldName::from(name.as_ref()));
308        values.push(field.into_value());
309    }
310
311    if let Some(raw_struct) = raw_struct {
312        for name in raw_struct.names().iter() {
313            if present_typed_fields.contains(name.as_ref()) {
314                continue;
315            }
316            let Some(raw_field) = raw_struct.field(name.as_ref()) else {
317                continue;
318            };
319            if raw_field.is_null() {
320                continue;
321            }
322            let raw_field = if raw_field.dtype().is_variant() {
323                raw_field.cast(&DType::Variant(Nullability::NonNullable))?
324            } else {
325                Scalar::variant(raw_field)
326            };
327            names.push(FieldName::from(name.as_ref()));
328            values.push(raw_field.into_value());
329        }
330    }
331
332    let fields = StructFields::new(
333        FieldNames::from(names),
334        vec![DType::Variant(Nullability::NonNullable); values.len()],
335    );
336    Scalar::try_new(
337        DType::Struct(fields, Nullability::NonNullable),
338        Some(ScalarValue::Tuple(values)),
339    )
340}
341
// Unit-test module for this file; it currently contains no tests.
#[cfg(test)]
mod tests {}