Skip to main content

vortex_array/arrays/dict/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use kernel::PARENT_KERNELS;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use super::DictArray;
17use super::DictArrayParts;
18use super::DictMetadata;
19use super::take_canonical;
20use crate::AnyCanonical;
21use crate::ArrayRef;
22use crate::Canonical;
23use crate::DeserializeMetadata;
24use crate::DynArray;
25use crate::IntoArray;
26use crate::Precision;
27use crate::ProstMetadata;
28use crate::SerializeMetadata;
29use crate::arrays::ConstantArray;
30use crate::arrays::Primitive;
31use crate::arrays::dict::compute::rules::PARENT_RULES;
32use crate::buffer::BufferHandle;
33use crate::dtype::DType;
34use crate::dtype::Nullability;
35use crate::dtype::PType;
36use crate::executor::ExecutionCtx;
37use crate::executor::ExecutionResult;
38use crate::hash::ArrayEq;
39use crate::hash::ArrayHash;
40use crate::require_child;
41use crate::scalar::Scalar;
42use crate::serde::ArrayChildren;
43use crate::stats::StatsSetRef;
44use crate::vtable;
45use crate::vtable::Array;
46use crate::vtable::ArrayId;
47use crate::vtable::VTable;
48mod kernel;
49mod operations;
50mod validity;
51
// Invoke the crate's `vtable!` macro for the `Dict` encoding.
// NOTE(review): presumably this generates the shared glue between `Dict` and the generic
// array machinery — confirm against the macro definition in the crate root.
vtable!(Dict);
53
/// Marker type for the dictionary encoding.
///
/// `Dict` carries no data of its own; it is the type on which the array vtable for
/// `DictArray` is implemented.
#[derive(Debug, Clone)]
pub struct Dict;
56
impl Dict {
    /// Identifier of the dictionary encoding, `"vortex.dict"`.
    ///
    /// This is the value returned by `VTable::id` for `Dict`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.dict");
}
60
/// Array vtable for dictionary-encoded arrays ([`DictArray`]).
///
/// A dictionary array has exactly two children: child 0 is `codes` and child 1 is `values`
/// (the dictionary itself). Its logical length is the number of codes, and it owns no buffers
/// of its own.
impl VTable for Dict {
    type Array = DictArray;

    type Metadata = ProstMetadata<DictMetadata>;
    type OperationsVTable = Self;
    type ValidityVTable = Self;

    fn vtable(_array: &Self::Array) -> &Self {
        &Dict
    }

    /// Returns [`Dict::ID`].
    fn id(&self) -> ArrayId {
        Self::ID
    }

    /// Returns the number of codes, which is the logical length of the array.
    fn len(array: &DictArray) -> usize {
        array.codes.len()
    }

    fn dtype(array: &DictArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &DictArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    /// Hashes the dtype followed by the `codes` and `values` children.
    ///
    /// NOTE(review): `all_values_referenced` is neither hashed here nor compared in
    /// `array_eq`, so arrays differing only in that flag hash and compare equal. This is
    /// presumably intended, since the flag looks like a hint rather than data — confirm.
    fn array_hash<H: std::hash::Hasher>(array: &DictArray, state: &mut H, precision: Precision) {
        array.dtype.hash(state);
        array.codes.array_hash(state, precision);
        array.values.array_hash(state, precision);
    }

    /// Two dictionary arrays are equal when their dtypes match and both their `codes` and
    /// `values` children are equal at the given `precision`.
    fn array_eq(array: &DictArray, other: &DictArray, precision: Precision) -> bool {
        array.dtype == other.dtype
            && array.codes.array_eq(&other.codes, precision)
            && array.values.array_eq(&other.values, precision)
    }

    /// A dictionary array owns no buffers directly; all of its data lives in its children.
    fn nbuffers(_array: &DictArray) -> usize {
        0
    }

    /// Always panics, because a dictionary array has no buffers to index.
    fn buffer(_array: &DictArray, idx: usize) -> BufferHandle {
        vortex_panic!("DictArray buffer index {idx} out of bounds")
    }

    fn buffer_name(_array: &DictArray, _idx: usize) -> Option<String> {
        None
    }

    /// Always two children: `codes` (index 0) and `values` (index 1).
    fn nchildren(_array: &DictArray) -> usize {
        2
    }

    /// Returns `codes` for index 0 and `values` for index 1.
    ///
    /// # Panics
    /// Panics for any other index.
    fn child(array: &DictArray, idx: usize) -> ArrayRef {
        match idx {
            0 => array.codes().clone(),
            1 => array.values().clone(),
            _ => vortex_panic!("DictArray child index {idx} out of bounds"),
        }
    }

    /// Returns the name of child `idx`: `"codes"` for 0, `"values"` for 1.
    ///
    /// # Panics
    /// Panics for any other index.
    fn child_name(_array: &DictArray, idx: usize) -> String {
        match idx {
            0 => "codes".to_string(),
            1 => "values".to_string(),
            _ => vortex_panic!("DictArray child_name index {idx} out of bounds"),
        }
    }

    /// Captures what `build` needs to reconstruct the array.
    ///
    /// The metadata records:
    /// - the codes' primitive type,
    /// - the number of dictionary values,
    /// - whether the codes are nullable,
    /// - the `all_values_referenced` flag.
    ///
    /// # Errors
    /// Fails if the codes dtype cannot be converted to a [`PType`], or if the number of
    /// dictionary values does not fit in a `u32`.
    fn metadata(array: &DictArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(DictMetadata {
            codes_ptype: PType::try_from(array.codes().dtype())? as i32,
            values_len: u32::try_from(array.values().len()).map_err(|_| {
                vortex_err!(
                    "Dictionary values size {} overflowed u32",
                    array.values().len()
                )
            })?,
            is_nullable_codes: Some(array.codes().dtype().is_nullable()),
            all_values_referenced: Some(array.all_values_referenced),
        }))
    }

    /// Serializes the metadata; a dictionary array always produces metadata bytes.
    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.serialize()))
    }

    /// Decodes metadata from `bytes`.
    ///
    /// Only `bytes` is used; the dtype, length, buffers and session are ignored.
    fn deserialize(
        bytes: &[u8],
        _dtype: &DType,
        _len: usize,
        _buffers: &[BufferHandle],
        _session: &VortexSession,
    ) -> VortexResult<Self::Metadata> {
        let metadata = <Self::Metadata as DeserializeMetadata>::deserialize(bytes)?;
        Ok(ProstMetadata(metadata))
    }

    /// Reconstructs a [`DictArray`] from its metadata and its two children.
    ///
    /// Child 0 (`codes`) is fetched as a primitive array of `metadata.codes_ptype()` with
    /// `len` rows. Child 1 (`values`) is fetched with the array's `dtype` and
    /// `metadata.values_len` rows.
    ///
    /// # Errors
    /// Fails if there are not exactly two children, or if fetching either child fails.
    fn build(
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        _buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<DictArray> {
        if children.len() != 2 {
            vortex_bail!(
                "Expected 2 children for dict encoding, found {}",
                children.len()
            )
        }
        let codes_nullable = metadata
            .is_nullable_codes
            .map(Nullability::from)
            // If the metadata has no `is_nullable_codes` field, fall back to the nullability
            // of the values (and of the whole array), matching the earlier behavior.
            .unwrap_or_else(|| dtype.nullability());
        let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable);
        let codes = children.get(0, &codes_dtype, len)?;
        let values = children.get(1, dtype, metadata.values_len as usize)?;
        // If `all_values_referenced` is absent from the metadata, it is treated as `false`.
        let all_values_referenced = metadata.all_values_referenced.unwrap_or(false);

        // SAFETY: The child count was checked above. Both children were obtained through
        // `children.get` with the dtypes and lengths recorded in the metadata.
        // NOTE(review): Nothing in this function checks that every code is in bounds for
        // `values`. This relies on `ArrayChildren::get` and the serialized data being
        // well-formed — confirm.
        Ok(unsafe {
            DictArray::new_unchecked(codes, values).set_all_values_referenced(all_values_referenced)
        })
    }

    /// Replaces the `codes` and `values` children, given in that order.
    ///
    /// Only these two fields are reassigned. `dtype`, `all_values_referenced` and the cached
    /// stats are left untouched, so callers presumably must pass children compatible with the
    /// existing ones — confirm.
    ///
    /// # Errors
    /// Fails if `children` does not contain exactly two arrays.
    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() == 2,
            "DictArray expects exactly 2 children (codes, values), got {}",
            children.len()
        );
        let [codes, values]: [ArrayRef; 2] = children
            .try_into()
            .map_err(|_| vortex_err!("Failed to convert children to array"))?;
        array.codes = codes;
        array.values = values;
        Ok(())
    }

    /// Decodes the dictionary into a canonical array by taking `values` at each code.
    ///
    /// Empty arrays and arrays whose codes are all null are answered directly. Otherwise the
    /// codes are required as [`Primitive`] and the values in canonical form, and the result
    /// is produced by `take_canonical`.
    fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        // An empty dictionary decodes to an empty canonical array. Its dtype is the array
        // dtype made nullable if the codes are nullable.
        if array.is_empty() {
            let result_dtype = array
                .dtype()
                .union_nullability(array.codes().dtype().nullability());
            return Ok(ExecutionResult::done(Canonical::empty(&result_dtype)));
        }

        // NOTE(review): `require_child!` presumably hands control back to the executor until
        // child 0 has been executed into a `Primitive` array — confirm against the macro.
        let array = require_child!(array, array.codes(), 0 => Primitive);

        // TODO(joe): read this from the stats instead of computing it.
        // This is only an approximation of the full check. The output is also null wherever a
        // valid code points at a null value, and that case is not considered here. The
        // shortcut is still correct: if every code is null, every output element is null.
        if array.codes().all_invalid()? {
            return Ok(ExecutionResult::done(
                ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.codes.len())
                    .into_array(),
            ));
        }

        // Same requirement for child 1: `values` must be in some canonical form.
        let array = require_child!(array, array.values(), 1 => AnyCanonical);

        let DictArrayParts { codes, values, .. } =
            Arc::unwrap_or_clone(array).into_inner().into_parts();

        // The codes were required to be `Primitive` above, so this downcast is expected to
        // succeed.
        let codes = codes
            .try_into::<Primitive>()
            .ok()
            .vortex_expect("must be primitive");
        debug_assert!(values.is_canonical());
        // TODO: add canonical owned cast.
        // Until that exists, `to_canonical` is used to get a `Canonical` value from the
        // already-canonical `values`.
        let values = values.to_canonical()?;

        Ok(ExecutionResult::done(
            take_canonical(values, &codes, ctx)?.into_array(),
        ))
    }

    /// Delegates parent reduction to the dictionary `PARENT_RULES`.
    fn reduce_parent(
        array: &Array<Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_RULES.evaluate(array, parent, child_idx)
    }

    /// Delegates parent execution to the dictionary `PARENT_KERNELS`.
    fn execute_parent(
        array: &Array<Self>,
        parent: &ArrayRef,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
    }
}
259}