vortex_array/arrays/dict/
ops.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_error::VortexExpect;
7use vortex_scalar::Scalar;
8
9use super::{DictArray, DictVTable};
10use crate::arrays::{ConstantArray, ConstantVTable};
11use crate::vtable::OperationsVTable;
12use crate::{Array, ArrayRef, IntoArray};
13
14impl OperationsVTable<DictVTable> for DictVTable {
15    fn slice(array: &DictArray, range: Range<usize>) -> ArrayRef {
16        let sliced_code = array.codes().slice(range);
17        if sliced_code.is::<ConstantVTable>() {
18            let code = &sliced_code.scalar_at(0).as_primitive().as_::<usize>();
19            return if let Some(code) = code {
20                ConstantArray::new(array.values().scalar_at(*code), sliced_code.len()).into_array()
21            } else {
22                ConstantArray::new(Scalar::null(array.dtype().clone()), sliced_code.len())
23                    .to_array()
24            };
25        }
26        // SAFETY: slicing the codes preserves invariants
27        unsafe { DictArray::new_unchecked(sliced_code, array.values().clone()).into_array() }
28    }
29
30    fn scalar_at(array: &DictArray, index: usize) -> Scalar {
31        let Some(dict_index) = array.codes().scalar_at(index).as_primitive().as_::<usize>() else {
32            return Scalar::null(array.dtype().clone());
33        };
34
35        array
36            .values()
37            .scalar_at(dict_index)
38            .cast(array.dtype())
39            .vortex_expect("Array dtype will only differ by nullability")
40    }
41}
42
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::{IntoArray, assert_arrays_eq};

    /// A one-element slice over a single code collapses to a constant: the referenced value
    /// for a valid code, and a null of the dict's dtype for a null code.
    #[test]
    fn test_slice_into_const_dict() {
        let codes = PrimitiveArray::from_option_iter(vec![Some(0u32), None, Some(1)]).to_array();
        let values =
            PrimitiveArray::from_option_iter(vec![Some(0i32), Some(1), Some(2)]).to_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let dtype = dict.dtype().clone();

        let valid_slice = dict.slice(0..1).as_constant();
        assert_eq!(Some(Scalar::new(dtype.clone(), 0i32.into())), valid_slice);

        let null_slice = dict.slice(1..2).as_constant();
        assert_eq!(Some(Scalar::null(dtype)), null_slice);
    }

    /// A dict whose codes contain nulls compares equal to the matching nullable primitive
    /// array, with nulls wherever the codes are null.
    #[test]
    fn test_scalar_at_null_code() {
        let codes = PrimitiveArray::from_option_iter(vec![None, Some(0u32), None]).to_array();
        let dict = DictArray::try_new(codes, buffer![1i32].into_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter(vec![None, Some(1i32), None]).into_array();
        assert_arrays_eq!(dict, expected);
    }
}