vortex_dict/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{DictArray, DictVTable};
10
11impl CastKernel for DictVTable {
12    fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        // Cast the dictionary values to the target type
14        let casted_values = cast(array.values(), dtype)?;
15
16        let casted_codes = if dtype.nullability() != array.codes().dtype().nullability() {
17            cast(
18                array.codes(),
19                &array.codes().dtype().with_nullability(dtype.nullability()),
20            )?
21        } else {
22            array.codes().clone()
23        };
24
25        // SAFETY: casting does not alter invariants of the codes
26        unsafe {
27            Ok(Some(
28                DictArray::new_unchecked(casted_codes, casted_values).into_array(),
29            ))
30        }
31    }
32}
33
// Registers the adapter-wrapped `DictVTable` cast kernel via `register_kernel!`.
// NOTE(review): presumably this is what lets the generic `cast` compute function dispatch
// to the implementation above — confirm against the macro's definition.
register_kernel!(CastKernelAdapter(DictVTable).lift());
35
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::{ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::builders::dict_encode;
    use crate::{DictArray, DictVTable};

    /// Shorthand for a primitive dtype with the given element type and nullability.
    fn primitive(ptype: PType, nullability: Nullability) -> DType {
        DType::Primitive(ptype, nullability)
    }

    /// Asserts that both the codes and the values of `dict` have the `expected` nullability.
    fn assert_parts_nullability(dict: &DictArray, expected: Nullability) {
        assert_eq!(dict.codes().dtype().nullability(), expected);
        assert_eq!(dict.values().dtype().nullability(), expected);
    }

    /// Widening i32 -> i64 keeps the decoded values intact.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = primitive(PType::I64, Nullability::NonNullable);
        let source = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&source).unwrap();

        let widened = cast(dict.as_ref(), &target).unwrap();
        assert_eq!(widened.dtype(), &target);

        let decoded = widened.to_primitive();
        assert_eq!(decoded.as_slice::<i64>(), &[1i64, 2, 3, 2, 1]);
    }

    /// A dictionary built from data containing nulls can be cast to a wider nullable type.
    #[test]
    fn test_cast_dict_nullable() {
        let target = primitive(PType::I64, Nullability::Nullable);
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(source.as_ref()).unwrap();

        let casted = cast(dict.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Round-trips a null-free dictionary through NonNullable -> Nullable -> NonNullable and
    /// checks that codes and values track the requested nullability at every step.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let non_nullable_i32 = primitive(PType::I32, Nullability::NonNullable);
        let nullable_i32 = primitive(PType::I32, Nullability::Nullable);

        // A dictionary over data without nulls starts with NonNullable codes and values.
        let source = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&source).unwrap();
        assert_parts_nullability(&dict, Nullability::NonNullable);

        // Step 1: casting to the dtype it already has must leave both parts NonNullable.
        let non_nullable = cast(dict.as_ref(), &non_nullable_i32).unwrap();
        assert_eq!(non_nullable.dtype(), &non_nullable_i32);
        assert_parts_nullability(
            non_nullable.as_::<DictVTable>(),
            Nullability::NonNullable,
        );

        // Step 2: widening to Nullable must flip both codes and values.
        let nullable = cast(non_nullable.as_ref(), &nullable_i32).unwrap();
        assert_eq!(nullable.dtype(), &nullable_i32);
        assert_parts_nullability(nullable.as_::<DictVTable>(), Nullability::Nullable);

        // Step 3: narrowing back to NonNullable must flip both parts again.
        let back_to_non_nullable = cast(nullable.as_ref(), &non_nullable_i32).unwrap();
        assert_eq!(back_to_non_nullable.dtype(), &non_nullable_i32);
        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_parts_nullability(back_dict, Nullability::NonNullable);

        // The decoded data must be untouched by the round trip.
        let original_values = dict.to_primitive();
        let final_values = back_dict.to_primitive();
        assert_eq!(
            original_values.as_slice::<i32>(),
            final_values.as_slice::<i32>()
        );
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}