Skip to main content

vortex_array/arrays/dict/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16use crate::validity::Validity;
17
18impl CastReduce for Dict {
19    fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20        // Can have un-reference null values making the cast of values fail without a possible mask.
21        // TODO(joe): optimize this, could look at accessible values and fill_null not those?
22        if !dtype.is_nullable()
23            && array.values().dtype().is_nullable()
24            && !matches!(
25                array.values().validity()?,
26                Validity::NonNullable | Validity::AllValid
27            )
28        {
29            return Ok(None);
30        }
31        // Cast the dictionary values to the target type
32        let casted_values = array.values().cast(dtype.clone())?;
33
34        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
35        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
36            array
37                .codes()
38                .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
39        } else {
40            array.codes().clone()
41        };
42
43        // SAFETY: casting does not alter invariants of the codes
44        Ok(Some(
45            unsafe {
46                DictArray::new_unchecked(casted_codes, casted_values)
47                    .set_all_values_referenced(array.has_all_values_referenced())
48            }
49            .into_array(),
50        ))
51    }
52}
53
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::Dict;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    /// Widening a non-nullable i32 dictionary to i64 yields the target dtype and the same
    /// logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source = buffer![1i32, 2, 3, 2, 1].into_array();
        let encoded = dict_encode(&source).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1i64, 2, 3, 2, 1])
        );
    }

    /// Casting a dictionary that contains nulls to a nullable wider type reports the
    /// requested dtype.
    #[test]
    fn test_cast_dict_nullable() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let encoded = dict_encode(&source.into_array()).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Round-trips a null-free dictionary through NonNullable -> Nullable -> NonNullable and
    /// checks the nullability of codes and values at each step.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let i32_non_nullable = DType::Primitive(PType::I32, Nullability::NonNullable);
        let i32_nullable = DType::Primitive(PType::I32, Nullability::Nullable);

        // Start from a dictionary built over data without any nulls.
        let source = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&source).unwrap();

        // Precondition: neither the codes nor the values are nullable.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 1: cast to the dtype the array already has.
        let non_nullable = dict
            .clone()
            .into_array()
            .cast(i32_non_nullable.clone())
            .unwrap();
        assert_eq!(non_nullable.dtype(), &i32_non_nullable);

        // Codes and values must both remain NonNullable.
        let non_nullable_dict = non_nullable.as_::<Dict>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 2: widen to a nullable dtype.
        let nullable = non_nullable.cast(i32_nullable.clone()).unwrap();
        assert_eq!(nullable.dtype(), &i32_nullable);

        // Only the values pick up nullability; the codes keep their NonNullable dtype.
        let nullable_dict = nullable.as_::<Dict>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Step 3: narrow back to the non-nullable dtype.
        let back_to_non_nullable = nullable.cast(i32_non_nullable.clone()).unwrap();
        assert_eq!(back_to_non_nullable.dtype(), &i32_non_nullable);

        // The logical values must survive the round trip unchanged.
        let expected = dict.as_array().to_primitive();
        let actual = back_to_non_nullable.to_primitive();
        assert_arrays_eq!(expected, actual);
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A dictionary whose only null value is never referenced by a code must still be
    /// castable to a non-nullable dtype.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        use crate::arrays::DictArray;
        use crate::validity::Validity;

        // Values: [1.0, null, 3.0] — the entry at index 1 is null.
        // Codes:  [0, 2, 0]        — only indices 0 and 2 are ever referenced.
        let dict_values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let dict_codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(dict_codes, dict_values).unwrap();

        // The dictionary reports a nullable dtype because its values are nullable, even
        // though every code points at a valid entry.
        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        // No logical element is null, so dropping nullability is expected to work.
        let target = DType::Primitive(PType::F64, Nullability::NonNullable);
        let outcome = dict.into_array().cast(target.clone());
        assert!(
            outcome.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );

        let casted = outcome.unwrap();
        assert_eq!(casted.dtype(), &target);
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}