vortex_dict/
ops.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{ConstantArray, ConstantVTable};
5use vortex_array::vtable::OperationsVTable;
6use vortex_array::{Array, ArrayRef, IntoArray};
7use vortex_error::VortexExpect;
8use vortex_scalar::Scalar;
9
10use crate::{DictArray, DictVTable};
11
12impl OperationsVTable<DictVTable> for DictVTable {
13    fn slice(array: &DictArray, start: usize, stop: usize) -> ArrayRef {
14        let sliced_code = array.codes().slice(start, stop);
15        if sliced_code.is::<ConstantVTable>() {
16            let code = &sliced_code.scalar_at(0).as_primitive().as_::<usize>();
17            return if let Some(code) = code {
18                ConstantArray::new(array.values().scalar_at(*code), sliced_code.len()).into_array()
19            } else {
20                let dtype = array.values().dtype().with_nullability(
21                    array.values().dtype().nullability() | array.codes().dtype().nullability(),
22                );
23                ConstantArray::new(Scalar::null(dtype), sliced_code.len()).to_array()
24            };
25        }
26        // SAFETY: slicing the codes preserves invariants
27        unsafe { DictArray::new_unchecked(sliced_code, array.values().clone()).into_array() }
28    }
29
30    fn scalar_at(array: &DictArray, index: usize) -> Scalar {
31        let dict_index: usize = array
32            .codes()
33            .scalar_at(index)
34            .as_ref()
35            .try_into()
36            .vortex_expect("code overflowed usize");
37        array.values().scalar_at(dict_index)
38    }
39}
40
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Length-1 slices of a dictionary are constant. Check the constant for a slice whose code
    /// is valid and for a slice whose code is null.
    #[test]
    fn test_slice_into_const_dict() {
        let codes = PrimitiveArray::from_option_iter(vec![Some(0u32), None, Some(1)]).to_array();
        let values =
            PrimitiveArray::from_option_iter(vec![Some(0i32), Some(1), Some(2)]).to_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let dtype = dict.dtype().clone();

        // Code 0 selects dictionary value 0.
        let valid_slice = dict.slice(0, 1).as_constant();
        assert_eq!(Some(Scalar::new(dtype.clone(), 0i32.into())), valid_slice);

        // The null code yields a null of the dictionary's dtype.
        let null_slice = dict.slice(1, 2).as_constant();
        assert_eq!(Some(Scalar::null(dtype)), null_slice);
    }
}
66}