Skip to main content

vortex_array/arrays/dict/compute/
slice.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::arrays::ConstantArray;
11use crate::arrays::ConstantVTable;
12use crate::arrays::DictArray;
13use crate::arrays::DictVTable;
14use crate::arrays::SliceReduce;
15use crate::scalar::Scalar;
16
17impl SliceReduce for DictVTable {
18    fn slice(array: &Self::Array, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
19        let sliced_code = array.codes().slice(range)?;
20        // TODO(joe): if the range is size 1 replace with a constant array
21        if let Some(code) = sliced_code.as_opt::<ConstantVTable>() {
22            let code = code.scalar().as_primitive().as_::<usize>();
23            return if let Some(code) = code {
24                let values = array.values().slice(code..code + 1)?;
25                Ok(Some(
26                    DictArray::new(
27                        ConstantArray::new(0u8, sliced_code.len()).into_array(),
28                        values,
29                    )
30                    .into_array(),
31                ))
32            } else {
33                Ok(Some(
34                    ConstantArray::new(Scalar::null(array.dtype().clone()), sliced_code.len())
35                        .to_array(),
36                ))
37            };
38        }
39        // SAFETY: slicing the codes preserves invariants.
40        Ok(Some(
41            unsafe { DictArray::new_unchecked(sliced_code, array.values().clone()) }.into_array(),
42        ))
43    }
44}
45
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType;
    use vortex_error::VortexResult;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::scalar::Scalar;

    /// Slicing a dictionary whose codes are a non-null constant yields that one value
    /// repeated for the length of the slice.
    #[test]
    fn slice_constant_valid_code() -> VortexResult<()> {
        let codes = ConstantArray::new(1u8, 5).into_array();
        let values = buffer![10i32, 20, 30].into_array();
        let dict = DictArray::new(codes, values);

        let actual = dict.slice(1..4)?;

        // Code 1 selects the value 20; the range 1..4 covers three rows.
        let want = PrimitiveArray::from_iter([20i32; 3]).into_array();
        assert_arrays_eq!(actual, want);
        Ok(())
    }

    /// Slicing a dictionary whose codes are a null constant yields only nulls.
    #[test]
    fn slice_constant_null_code() -> VortexResult<()> {
        let null_code = Scalar::null(DType::Primitive(PType::U8, Nullable));
        let codes = ConstantArray::new(null_code, 5).into_array();
        let values = buffer![10i32, 20, 30].into_array();
        let dict = DictArray::new(codes, values);

        let actual = dict.slice(1..4)?;

        // The range 1..4 covers three rows, each of them null.
        let want = PrimitiveArray::from_option_iter([None::<i32>; 3]).into_array();
        assert_arrays_eq!(actual, want);
        Ok(())
    }
}