Skip to main content

vortex_array/arrays/dict/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::DictArray;
7use super::DictVTable;
8use crate::ArrayRef;
9use crate::DynArray;
10use crate::IntoArray;
11use crate::builtins::ArrayBuiltins;
12use crate::dtype::DType;
13use crate::scalar_fn::fns::cast::CastReduce;
14
15impl CastReduce for DictVTable {
16    fn cast(array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // Can have un-reference null values making the cast of values fail without a possible mask.
18        // TODO(joe): optimize this, could look at accessible values and fill_null not those?
19        if !dtype.is_nullable()
20            && array.values().dtype().is_nullable()
21            && !array.values().all_valid()?
22        {
23            return Ok(None);
24        }
25        // Cast the dictionary values to the target type
26        let casted_values = array.values().cast(dtype.clone())?;
27
28        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
29        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
30            array
31                .codes()
32                .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
33        } else {
34            array.codes().clone()
35        };
36
37        // SAFETY: casting does not alter invariants of the codes
38        Ok(Some(
39            unsafe {
40                DictArray::new_unchecked(casted_codes, casted_values)
41                    .set_all_values_referenced(array.has_all_values_referenced())
42            }
43            .into_array(),
44        ))
45    }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::DictVTable;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    /// Casting a dict of `i32` to non-nullable `i64` yields the same logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let encoded = dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap();

        let widened = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(widened.dtype(), &target);

        assert_arrays_eq!(
            widened.to_primitive(),
            PrimitiveArray::from_iter([1i64, 2, 3, 2, 1])
        );
    }

    /// Casting a dict built from nullable input to a nullable target reports the target dtype.
    #[test]
    fn test_cast_dict_nullable() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let encoded = dict_encode(&source.into_array()).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Round-trips a null-free dict through NonNullable -> Nullable -> NonNullable and checks
    /// the nullability of codes and values after every step.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let i32_non_nullable = DType::Primitive(PType::I32, Nullability::NonNullable);
        let i32_nullable = DType::Primitive(PType::I32, Nullability::Nullable);

        // Dictionary built from input that contains no nulls.
        let original = dict_encode(&buffer![10i32, 20, 30, 40].into_array()).unwrap();

        // Starting point: both the codes and the values are NonNullable.
        assert_eq!(
            original.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            original.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 1: cast to NonNullable, which is already the array's dtype.
        let as_non_nullable = original
            .clone()
            .into_array()
            .cast(i32_non_nullable.clone())
            .unwrap();
        assert_eq!(as_non_nullable.dtype(), &i32_non_nullable);

        // Codes and values must both still be NonNullable.
        let first_dict = as_non_nullable.as_::<DictVTable>();
        assert_eq!(
            first_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            first_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 2: cast to Nullable.
        let as_nullable = as_non_nullable.cast(i32_nullable.clone()).unwrap();
        assert_eq!(as_nullable.dtype(), &i32_nullable);

        // Only the values become Nullable; the codes are expected to stay NonNullable.
        let second_dict = as_nullable.as_::<DictVTable>();
        assert_eq!(
            second_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            second_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Step 3: cast back to NonNullable.
        let round_tripped = as_nullable.cast(i32_non_nullable.clone()).unwrap();
        assert_eq!(round_tripped.dtype(), &i32_non_nullable);

        // Codes and values are both NonNullable once more.
        let final_dict = round_tripped.as_::<DictVTable>();
        assert_eq!(
            final_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            final_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // The logical contents must survive the round trip untouched.
        assert_arrays_eq!(original.to_primitive(), final_dict.to_primitive());
    }

    /// Runs the shared cast conformance suite over a handful of dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A null dictionary value that no code points at must not prevent a cast to NonNullable.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        use crate::arrays::DictArray;
        use crate::validity::Validity;

        let nullable_f64 = DType::Primitive(PType::F64, Nullability::Nullable);
        let non_nullable_f64 = DType::Primitive(PType::F64, Nullability::NonNullable);

        // Values: [1.0, null, 3.0] — slot 1 is null.
        // Codes:  [0, 2, 0]        — only slots 0 and 2 are referenced, never slot 1.
        let validity = Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true]));
        let values =
            PrimitiveArray::new(buffer![1.0f64, 0.0f64, 3.0f64], validity).into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // Nullable values make the dict Nullable, even though every code hits a valid value.
        assert_eq!(dict.dtype(), &nullable_f64);

        // Every logical element is non-null, so the cast to NonNullable has to go through.
        let result = dict.into_array().cast(non_nullable_f64.clone());
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );

        let casted = result.unwrap();
        assert_eq!(casted.dtype(), &non_nullable_f64);
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}