vortex_dict/
ops.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{ConstantArray, ConstantVTable};
5use vortex_array::vtable::OperationsVTable;
6use vortex_array::{Array, ArrayRef, IntoArray};
7use vortex_error::VortexResult;
8use vortex_scalar::Scalar;
9
10use crate::{DictArray, DictVTable};
11
12impl OperationsVTable<DictVTable> for DictVTable {
13    fn slice(array: &DictArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
14        let sliced_code = array.codes().slice(start, stop)?;
15        if sliced_code.is::<ConstantVTable>() {
16            let code = Option::<usize>::try_from(&sliced_code.scalar_at(0)?)?;
17            return if let Some(code) = code {
18                Ok(
19                    ConstantArray::new(array.values().scalar_at(code)?, sliced_code.len())
20                        .to_array(),
21                )
22            } else {
23                let dtype = array.values().dtype().with_nullability(
24                    array.values().dtype().nullability() | array.codes().dtype().nullability(),
25                );
26                Ok(ConstantArray::new(Scalar::null(dtype), sliced_code.len()).to_array())
27            };
28        }
29        DictArray::try_new(sliced_code, array.values().clone()).map(|a| a.into_array())
30    }
31
32    fn scalar_at(array: &DictArray, index: usize) -> VortexResult<Scalar> {
33        let dict_index: usize = array.codes().scalar_at(index)?.as_ref().try_into()?;
34        array.values().scalar_at(dict_index)
35    }
36}
37
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Slicing a dictionary down to a single run of one code should produce a constant:
    /// the referenced value for a valid code, and a typed null for a null code.
    #[test]
    fn test_slice_into_const_dict() {
        let codes = PrimitiveArray::from_option_iter(vec![Some(0u32), None, Some(1)]).to_array();
        let values =
            PrimitiveArray::from_option_iter(vec![Some(0i32), Some(1), Some(2)]).to_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let dict_dtype = dict.dtype().clone();

        // Row 0 has code 0, which points at the value 0.
        let head = dict.slice(0, 1).unwrap();
        assert_eq!(
            Some(Scalar::new(dict_dtype.clone(), 0i32.into())),
            head.as_constant()
        );

        // Row 1 has a null code, so the slice is a null constant of the dictionary's dtype.
        let middle = dict.slice(1, 2).unwrap();
        assert_eq!(Some(Scalar::null(dict_dtype)), middle.as_constant());
    }
}