vortex_array/arrays/dict/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use super::DictArray;
8use super::DictVTable;
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::compute::CastKernel;
13use crate::compute::CastKernelAdapter;
14use crate::compute::cast;
15use crate::register_kernel;
16
17impl CastKernel for DictVTable {
18    fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        // Cast the dictionary values to the target type
20        let casted_values = cast(array.values(), dtype)?;
21
22        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
23        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
24            cast(
25                array.codes(),
26                &array.codes().dtype().with_nullability(dtype.nullability()),
27            )?
28        } else {
29            array.codes().clone()
30        };
31
32        // SAFETY: casting does not alter invariants of the codes
33        Ok(Some(
34            unsafe {
35                DictArray::new_unchecked(casted_codes, casted_values)
36                    .set_all_values_referenced(array.has_all_values_referenced())
37            }
38            .into_array(),
39        ))
40    }
41}
42
43register_kernel!(CastKernelAdapter(DictVTable).lift());
44
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictVTable;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Widening i32 -> i64 keeps the logical values of the dictionary-encoded array.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dictionary built from nullable input to a nullable target reports that dtype.
    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(values.as_ref()).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );
    }

    /// Round-trips a null-free dictionary through NonNullable -> Nullable -> NonNullable and
    /// checks the nullability of codes and values at each step.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        // Create an AllValid dict array (no nulls)
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Verify initial state - codes should be NonNullable, values should be NonNullable
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to NonNullable (should be identity since already NonNullable)
        let non_nullable = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that codes and values are still NonNullable
        let non_nullable_dict = non_nullable.as_::<DictVTable>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to Nullable
        let nullable = cast(
            non_nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        // Check that only the values became Nullable; the codes are reused and stay NonNullable
        let nullable_dict = nullable.as_::<DictVTable>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Cast back to NonNullable
        let back_to_non_nullable = cast(
            nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that both codes and values are NonNullable again
        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Verify values are unchanged
        let original_values = dict.to_primitive();
        let final_values = back_dict.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}