vortex_dict/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{DictArray, DictVTable};
10
11impl CastKernel for DictVTable {
12    fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        // Cast the dictionary values to the target type
14        let casted_values = cast(array.values(), dtype)?;
15
16        let casted_codes = if dtype.nullability() != array.codes().dtype().nullability() {
17            cast(
18                array.codes(),
19                &array.codes().dtype().with_nullability(dtype.nullability()),
20            )?
21        } else {
22            array.codes().clone()
23        };
24
25        // Create a new dictionary array with the same codes but casted values
26        Ok(Some(
27            DictArray::try_new(casted_codes, casted_values)?.into_array(),
28        ))
29    }
30}
31
32register_kernel!(CastKernelAdapter(DictVTable).lift());
33
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::DictVTable;
    use crate::builders::dict_encode;

    /// Widening a non-nullable i32 dictionary to i64 yields the requested dtype and
    /// decodes to the same logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(decoded.as_slice::<i64>(), &[1i64, 2, 3, 2, 1]);
    }

    /// Casting a dictionary that contains nulls to a nullable, wider type yields the
    /// requested dtype and preserves every non-null value.
    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(values.as_ref()).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Check the data as well as the dtype. Only the valid slots (indices 0, 2
        // and 3) are compared; the physical contents of null slots are not asserted.
        let decoded = casted.to_canonical().unwrap().into_primitive().unwrap();
        let decoded_values = decoded.as_slice::<i64>();
        assert_eq!(decoded_values.len(), 5);
        assert_eq!(decoded_values[0], 10);
        assert_eq!(decoded_values[2], 20);
        assert_eq!(decoded_values[3], 10);
    }

    /// Round-trips a null-free dictionary through NonNullable -> Nullable -> NonNullable.
    /// Checks that the codes and values follow the target nullability at each step and
    /// that the decoded data is unchanged at the end.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        // Create an AllValid dict array (no nulls)
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Verify initial state - codes should be NonNullable, values should be NonNullable
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to NonNullable (should be identity since already NonNullable)
        let non_nullable = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that codes and values are still NonNullable
        let non_nullable_dict = non_nullable.as_::<DictVTable>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to Nullable
        let nullable = cast(
            non_nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        // Check that both codes and values are now Nullable
        let nullable_dict = nullable.as_::<DictVTable>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::Nullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Cast back to NonNullable
        let back_to_non_nullable = cast(
            nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that both codes and values are NonNullable again
        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Verify values are unchanged
        let original_values = dict.to_canonical().unwrap().into_primitive().unwrap();
        let final_values = back_dict.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(
            original_values.as_slice::<i32>(),
            final_values.as_slice::<i32>()
        );
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs:
    /// signed and unsigned integers, an integer column with nulls, and floats.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: vortex_array::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}