Skip to main content

vortex_array/arrays/dict/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hasher;
5
6use prost::Message;
7use smallvec::smallvec;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15
16use super::DictData;
17use super::DictMetadata;
18use super::DictOwnedExt;
19use super::DictParts;
20use super::array::DictSlots;
21use super::array::DictSlotsView;
22use crate::AnyCanonical;
23use crate::ArrayEq;
24use crate::ArrayHash;
25use crate::ArrayRef;
26use crate::Canonical;
27use crate::EqMode;
28use crate::IntoArray;
29use crate::array::Array;
30use crate::array::ArrayId;
31use crate::array::ArrayParts;
32use crate::array::ArrayView;
33use crate::array::VTable;
34use crate::array::with_empty_buffers;
35use crate::arrays::ConstantArray;
36use crate::arrays::Primitive;
37use crate::arrays::dict::DictArrayExt;
38use crate::arrays::dict::DictArraySlotsExt;
39use crate::arrays::dict::compute::rules::PARENT_RULES;
40use crate::arrays::dict::execute::take_canonical;
41use crate::buffer::BufferHandle;
42use crate::builders::ArrayBuilder;
43use crate::dtype::DType;
44use crate::dtype::Nullability;
45use crate::dtype::PType;
46use crate::executor::ExecutionCtx;
47use crate::executor::ExecutionResult;
48use crate::require_child;
49use crate::scalar::Scalar;
50use crate::serde::ArrayChildren;
51
52mod kernel;
53mod operations;
54mod validity;
55
56/// A [`Dict`]-encoded Vortex array.
57pub type DictArray = Array<Dict>;
58
59pub(crate) fn initialize(session: &VortexSession) {
60    kernel::initialize(session);
61}
62
/// Field-less marker type for the dictionary encoding.
///
/// The encoding's behaviour is supplied by the [`VTable`] implementation on this type.
#[derive(Debug, Clone)]
pub struct Dict;
65
66impl ArrayHash for DictData {
67    fn array_hash<H: Hasher>(&self, _state: &mut H, _accuracy: EqMode) {}
68}
69
70impl ArrayEq for DictData {
71    fn array_eq(&self, _other: &Self, _accuracy: EqMode) -> bool {
72        true
73    }
74}
75
76impl VTable for Dict {
77    type TypedArrayData = DictData;
78
79    type OperationsVTable = Self;
80    type ValidityVTable = Self;
81
82    fn id(&self) -> ArrayId {
83        static ID: CachedId = CachedId::new("vortex.dict");
84        *ID
85    }
86
87    fn validate(
88        &self,
89        _data: &DictData,
90        dtype: &DType,
91        len: usize,
92        slots: &[Option<ArrayRef>],
93    ) -> VortexResult<()> {
94        let view = DictSlotsView::from_slots(slots);
95        let codes = view.codes;
96        let values = view.values;
97        vortex_ensure!(codes.len() == len, "DictArray codes length mismatch");
98        vortex_ensure!(
99            values
100                .dtype()
101                .union_nullability(codes.dtype().nullability())
102                == *dtype,
103            "DictArray dtype does not match codes/values dtype"
104        );
105        Ok(())
106    }
107
108    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
109        0
110    }
111
112    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
113        vortex_panic!("DictArray buffer index {idx} out of bounds")
114    }
115
116    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
117        None
118    }
119
120    fn with_buffers(
121        &self,
122        array: ArrayView<'_, Self>,
123        buffers: &[BufferHandle],
124    ) -> VortexResult<ArrayParts<Self>> {
125        with_empty_buffers(self, array, buffers)
126    }
127
128    fn serialize(
129        array: ArrayView<'_, Self>,
130        _session: &VortexSession,
131    ) -> VortexResult<Option<Vec<u8>>> {
132        Ok(Some(
133            DictMetadata {
134                codes_ptype: PType::try_from(array.codes().dtype())? as i32,
135                values_len: u32::try_from(array.values().len()).map_err(|_| {
136                    vortex_err!(
137                        "Dictionary values size {} overflowed u32",
138                        array.values().len()
139                    )
140                })?,
141                is_nullable_codes: Some(array.codes().dtype().is_nullable()),
142                all_values_referenced: Some(array.has_all_values_referenced()),
143            }
144            .encode_to_vec(),
145        ))
146    }
147
148    fn deserialize(
149        &self,
150        dtype: &DType,
151        len: usize,
152        metadata: &[u8],
153        _buffers: &[BufferHandle],
154        children: &dyn ArrayChildren,
155        _session: &VortexSession,
156    ) -> VortexResult<ArrayParts<Self>> {
157        let metadata = DictMetadata::decode(metadata)?;
158        if children.len() != 2 {
159            vortex_bail!(
160                "Expected 2 children for dict encoding, found {}",
161                children.len()
162            )
163        }
164        let codes_nullable = metadata
165            .is_nullable_codes
166            .map(Nullability::from)
167            // If no `is_nullable_codes` metadata use the nullability of the values
168            // (and whole array) as before.
169            .unwrap_or_else(|| dtype.nullability());
170        let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable);
171        let codes = children.get(0, &codes_dtype, len)?;
172        let values = children.get(1, dtype, metadata.values_len as usize)?;
173        let all_values_referenced = metadata.all_values_referenced.unwrap_or(false);
174
175        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, unsafe {
176            DictData::new_unchecked().set_all_values_referenced(all_values_referenced)
177        })
178        .with_slots(smallvec![Some(codes), Some(values)]))
179    }
180
181    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
182        DictSlots::NAMES[idx].to_string()
183    }
184
185    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
186        if array.is_empty() {
187            let result_dtype = array
188                .dtype()
189                .union_nullability(array.codes().dtype().nullability());
190            return Ok(ExecutionResult::done(Canonical::empty(&result_dtype)));
191        }
192
193        let array = require_child!(array, array.codes(), DictSlots::CODES => Primitive);
194
195        if array.codes().validity()?.definitely_all_null() {
196            return Ok(ExecutionResult::done(ConstantArray::new(
197                Scalar::null(array.dtype().as_nullable()),
198                array.codes().len(),
199            )));
200        }
201
202        let array = require_child!(array, array.values(), DictSlots::VALUES => AnyCanonical);
203
204        let DictParts { values, codes, .. } = array.into_parts();
205
206        Ok(ExecutionResult::done(take_canonical(
207            values.as_::<AnyCanonical>(),
208            &codes.downcast::<Primitive>(),
209            ctx,
210        )?))
211    }
212
213    fn append_to_builder(
214        array: ArrayView<'_, Self>,
215        builder: &mut dyn ArrayBuilder,
216        ctx: &mut ExecutionCtx,
217    ) -> VortexResult<()> {
218        if !array.is_empty()
219            && let (Some(codes), Some(values)) = (
220                array.codes().as_opt::<Primitive>(),
221                array.values().as_opt::<AnyCanonical>(),
222            )
223            && !codes.validity()?.definitely_all_null()
224        {
225            let codes = codes.into_owned();
226            let canonical = take_canonical(values, &codes, ctx)?.into_array();
227            canonical.append_to_builder(builder, ctx)?;
228            return Ok(());
229        }
230
231        let canonical = array
232            .array()
233            .clone()
234            .execute::<Canonical>(ctx)?
235            .into_array();
236        canonical.append_to_builder(builder, ctx)?;
237        Ok(())
238    }
239
240    fn reduce_parent(
241        array: ArrayView<'_, Self>,
242        parent: &ArrayRef,
243        child_idx: usize,
244    ) -> VortexResult<Option<ArrayRef>> {
245        PARENT_RULES.evaluate(array, parent, child_idx)
246    }
247}