Skip to main content

vortex_array/arrays/dict/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use super::DictArray;
8use super::DictVTable;
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::builtins::ArrayBuiltins;
13use crate::compute::CastReduce;
14
15impl CastReduce for DictVTable {
16    fn cast(array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // Can have un-reference null values making the cast of values fail without a possible mask.
18        // TODO(joe): optimize this, could look at accessible values and fill_null not those?
19        if !dtype.is_nullable()
20            && array.values().dtype().is_nullable()
21            && !array.values().all_valid()?
22        {
23            return Ok(None);
24        }
25        // Cast the dictionary values to the target type
26        let casted_values = array.values().cast(dtype.clone())?;
27
28        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
29        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
30            array
31                .codes()
32                .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
33        } else {
34            array.codes().clone()
35        };
36
37        // SAFETY: casting does not alter invariants of the codes
38        Ok(Some(
39            unsafe {
40                DictArray::new_unchecked(casted_codes, casted_values)
41                    .set_all_values_referenced(array.has_all_values_referenced())
42            }
43            .into_array(),
44        ))
45    }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::arrays::dict::DictVTable;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = dict
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(values.as_ref()).unwrap();

        let casted = dict
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // The cast must preserve both the valid values and the positions of the nulls.
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        // Create an AllValid dict array (no nulls)
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Verify initial state - codes should be NonNullable, values should be NonNullable
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to NonNullable (should be identity since already NonNullable)
        let non_nullable = dict
            .to_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that codes and values are still NonNullable
        let non_nullable_dict = non_nullable.as_::<DictVTable>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to Nullable
        let nullable = non_nullable
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        // Only the values become Nullable; the codes are reused as-is and stay NonNullable.
        let nullable_dict = nullable.as_::<DictVTable>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Cast back to NonNullable
        let back_to_non_nullable = nullable
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that both codes and values are NonNullable again
        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Verify values are unchanged
        let original_values = dict.to_primitive();
        let final_values = back_dict.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    #[test]
    fn test_cast_dict_nullable_codes_to_nonnullable() {
        // Codes are Nullable but hold no nulls, so the dict can be cast to NonNullable.
        // The cast must strip nullability from the codes as well as the values.
        let codes = PrimitiveArray::new(buffer![0u32, 1, 0], Validity::AllValid).into_array();
        let values = buffer![10i32, 20].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        assert_eq!(dict.codes().dtype().nullability(), Nullability::Nullable);

        let casted = dict
            .to_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let casted_dict = casted.as_::<DictVTable>();
        assert_eq!(
            casted_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([10i32, 20, 10])
        );
    }

    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(array.as_ref());
    }

    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        // Create a dict with nullable values that have unreferenced null entries.
        // Values: [1.0, null, 3.0] (index 1 is null but no code points to it)
        // Codes: [0, 2, 0] (only reference indices 0 and 2, never 1)
        let values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // The dict is Nullable (because values are nullable), but all codes point to valid values.
        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        // Casting to NonNullable should succeed since all logical values are non-null.
        let result = dict
            .to_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable));
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );
        let casted = result.unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}