Skip to main content

vortex_array/arrays/dict/compute/
slice.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::Constant;
12use crate::arrays::ConstantArray;
13use crate::arrays::Dict;
14use crate::arrays::DictArray;
15use crate::arrays::dict::DictArraySlotsExt;
16use crate::arrays::slice::SliceReduce;
17use crate::scalar::Scalar;
18
19impl SliceReduce for Dict {
20    fn slice(array: ArrayView<'_, Self>, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
21        let sliced_code = array.codes().slice(range)?;
22        // TODO(joe): if the range is size 1 replace with a constant array
23        if let Some(code) = sliced_code.as_opt::<Constant>() {
24            let code = code.scalar().as_primitive().as_::<usize>();
25            return if let Some(code) = code {
26                let values = array.values().slice(code..code + 1)?;
27                Ok(Some(
28                    DictArray::new(
29                        ConstantArray::new(0u8, sliced_code.len()).into_array(),
30                        values,
31                    )
32                    .into_array(),
33                ))
34            } else {
35                Ok(Some(
36                    ConstantArray::new(Scalar::null(array.dtype().clone()), sliced_code.len())
37                        .into_array(),
38                ))
39            };
40        }
41        // SAFETY: slicing the codes preserves invariants.
42        Ok(Some(
43            unsafe { DictArray::new_unchecked(sliced_code, array.values().clone()) }.into_array(),
44        ))
45    }
46}
47
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::scalar::Scalar;

    /// Constant non-null codes: every element of the slice resolves to the same value.
    #[test]
    fn slice_constant_valid_code() -> VortexResult<()> {
        let dict = DictArray::new(
            ConstantArray::new(1u8, 5).into_array(),
            buffer![10i32, 20, 30].into_array(),
        );
        let sliced = dict.slice(1..4)?;
        let expected = PrimitiveArray::from_iter([20i32, 20, 20]).into_array();
        assert_arrays_eq!(sliced, expected);
        Ok(())
    }

    /// Constant null codes: every element of the slice is null.
    #[test]
    fn slice_constant_null_code() -> VortexResult<()> {
        let dict = DictArray::new(
            ConstantArray::new(Scalar::null(DType::Primitive(PType::U8, Nullable)), 5).into_array(),
            buffer![10i32, 20, 30].into_array(),
        );
        let sliced = dict.slice(1..4)?;
        let expected =
            PrimitiveArray::from_option_iter([Option::<i32>::None, None, None]).into_array();
        assert_arrays_eq!(sliced, expected);
        Ok(())
    }

    /// Non-constant codes take the general path: codes are sliced, values are kept as-is.
    #[test]
    fn slice_non_constant_codes() -> VortexResult<()> {
        let dict = DictArray::new(
            buffer![0u8, 2, 1, 2, 0].into_array(),
            buffer![10i32, 20, 30].into_array(),
        );
        // Codes 1..4 are [2, 1, 2], which map to values [30, 20, 30].
        let sliced = dict.slice(1..4)?;
        let expected = PrimitiveArray::from_iter([30i32, 20, 30]).into_array();
        assert_arrays_eq!(sliced, expected);
        Ok(())
    }
}