vortex_array/arrays/dict/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use super::{DictArray, DictVTable};
8use crate::compute::{CastKernel, CastKernelAdapter, cast};
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl CastKernel for DictVTable {
12    fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        // Cast the dictionary values to the target type
14        let casted_values = cast(array.values(), dtype)?;
15
16        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
17        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
18            cast(
19                array.codes(),
20                &array.codes().dtype().with_nullability(dtype.nullability()),
21            )?
22        } else {
23            array.codes().clone()
24        };
25
26        // SAFETY: casting does not alter invariants of the codes
27        Ok(Some(
28            unsafe { DictArray::new_unchecked(casted_codes, casted_values) }.into_array(),
29        ))
30    }
31}
32
// Registers the `CastKernel` implementation for `DictVTable`, wrapped in a
// `CastKernelAdapter`, through the crate's `register_kernel!` macro.
register_kernel!(CastKernelAdapter(DictVTable).lift());
34
35#[cfg(test)]
36mod tests {
37    use rstest::rstest;
38    use vortex_buffer::buffer;
39    use vortex_dtype::{DType, Nullability, PType};
40
41    use crate::arrays::PrimitiveArray;
42    use crate::arrays::dict::DictVTable;
43    use crate::builders::dict::dict_encode;
44    use crate::compute::cast;
45    use crate::compute::conformance::cast::test_cast_conformance;
46    use crate::{IntoArray, ToCanonical, assert_arrays_eq};
47
48    #[test]
49    fn test_cast_dict_to_wider_type() {
50        let values = buffer![1i32, 2, 3, 2, 1].into_array();
51        let dict = dict_encode(&values).unwrap();
52
53        let casted = cast(
54            dict.as_ref(),
55            &DType::Primitive(PType::I64, Nullability::NonNullable),
56        )
57        .unwrap();
58        assert_eq!(
59            casted.dtype(),
60            &DType::Primitive(PType::I64, Nullability::NonNullable)
61        );
62
63        let decoded = casted.to_primitive();
64        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
65    }
66
67    #[test]
68    fn test_cast_dict_nullable() {
69        let values =
70            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
71        let dict = dict_encode(values.as_ref()).unwrap();
72
73        let casted = cast(
74            dict.as_ref(),
75            &DType::Primitive(PType::I64, Nullability::Nullable),
76        )
77        .unwrap();
78        assert_eq!(
79            casted.dtype(),
80            &DType::Primitive(PType::I64, Nullability::Nullable)
81        );
82    }
83
84    #[test]
85    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
86        // Create an AllValid dict array (no nulls)
87        let values = buffer![10i32, 20, 30, 40].into_array();
88        let dict = dict_encode(&values).unwrap();
89
90        // Verify initial state - codes should be NonNullable, values should be NonNullable
91        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
92        assert_eq!(
93            dict.values().dtype().nullability(),
94            Nullability::NonNullable
95        );
96
97        // Cast to NonNullable (should be identity since already NonNullable)
98        let non_nullable = cast(
99            dict.as_ref(),
100            &DType::Primitive(PType::I32, Nullability::NonNullable),
101        )
102        .unwrap();
103        assert_eq!(
104            non_nullable.dtype(),
105            &DType::Primitive(PType::I32, Nullability::NonNullable)
106        );
107
108        // Check that codes and values are still NonNullable
109        let non_nullable_dict = non_nullable.as_::<DictVTable>();
110        assert_eq!(
111            non_nullable_dict.codes().dtype().nullability(),
112            Nullability::NonNullable
113        );
114        assert_eq!(
115            non_nullable_dict.values().dtype().nullability(),
116            Nullability::NonNullable
117        );
118
119        // Cast to Nullable
120        let nullable = cast(
121            non_nullable.as_ref(),
122            &DType::Primitive(PType::I32, Nullability::Nullable),
123        )
124        .unwrap();
125        assert_eq!(
126            nullable.dtype(),
127            &DType::Primitive(PType::I32, Nullability::Nullable)
128        );
129
130        // Check that both codes and values are now Nullable
131        let nullable_dict = nullable.as_::<DictVTable>();
132        assert_eq!(
133            nullable_dict.codes().dtype().nullability(),
134            Nullability::NonNullable
135        );
136        assert_eq!(
137            nullable_dict.values().dtype().nullability(),
138            Nullability::Nullable
139        );
140
141        // Cast back to NonNullable
142        let back_to_non_nullable = cast(
143            nullable.as_ref(),
144            &DType::Primitive(PType::I32, Nullability::NonNullable),
145        )
146        .unwrap();
147        assert_eq!(
148            back_to_non_nullable.dtype(),
149            &DType::Primitive(PType::I32, Nullability::NonNullable)
150        );
151
152        // Check that both codes and values are NonNullable again
153        let back_dict = back_to_non_nullable.as_::<DictVTable>();
154        assert_eq!(
155            back_dict.codes().dtype().nullability(),
156            Nullability::NonNullable
157        );
158        assert_eq!(
159            back_dict.values().dtype().nullability(),
160            Nullability::NonNullable
161        );
162
163        // Verify values are unchanged
164        let original_values = dict.to_primitive();
165        let final_values = back_dict.to_primitive();
166        assert_arrays_eq!(original_values, final_values);
167    }
168
169    #[rstest]
170    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
171    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
172    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
173    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
174    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
175        test_cast_conformance(array.as_ref());
176    }
177}