Skip to main content

vortex_array/arrays/dict/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Dict {
18    fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        // Can have un-reference null values making the cast of values fail without a possible mask.
20        // TODO(joe): optimize this, could look at accessible values and fill_null not those?
21        if !dtype.is_nullable()
22            && array.values().dtype().is_nullable()
23            && !array.values().all_valid()?
24        {
25            return Ok(None);
26        }
27        // Cast the dictionary values to the target type
28        let casted_values = array.values().cast(dtype.clone())?;
29
30        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
31        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
32            array
33                .codes()
34                .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
35        } else {
36            array.codes().clone()
37        };
38
39        // SAFETY: casting does not alter invariants of the codes
40        Ok(Some(
41            unsafe {
42                DictArray::new_unchecked(casted_codes, casted_values)
43                    .set_all_values_referenced(array.has_all_values_referenced())
44            }
45            .into_array(),
46        ))
47    }
48}
49
#[cfg(test)]
mod tests {
    // Tests for pushing casts through `DictArray`: dtype/nullability of the result, the
    // nullability of the rebuilt codes and values, and the decoded contents.
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(&values.into_array()).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // The cast must preserve both the values and the positions of the nulls.
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        // Create an AllValid dict array (no nulls)
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Verify initial state - codes should be NonNullable, values should be NonNullable
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to NonNullable (should be identity since already NonNullable)
        let non_nullable = dict
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that codes and values are still NonNullable
        let non_nullable_dict = non_nullable.as_::<Dict>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to Nullable
        let nullable = non_nullable
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        // Widening to Nullable only touches the values; the codes are reused and therefore
        // stay NonNullable.
        let nullable_dict = nullable.as_::<Dict>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Cast back to NonNullable
        let back_to_non_nullable = nullable
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Check that both codes and values are NonNullable again
        let back_dict = back_to_non_nullable.as_::<Dict>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Verify values are unchanged
        let original_values = dict.as_array().to_primitive();
        let final_values = back_dict.array().to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(&array);
    }

    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        // Create a dict with nullable values that have unreferenced null entries.
        // Values: [1.0, null, 3.0] (index 1 is null but no code points to it)
        // Codes: [0, 2, 0] (only reference indices 0 and 2, never 1)
        let values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // The dict is Nullable (because values are nullable), but all codes point to valid values.
        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        // Casting to NonNullable should succeed since all logical values are non-null.
        // Surface the underlying error on failure instead of a bare `is_ok()` assertion.
        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))
            .unwrap_or_else(|err| {
                panic!(
                    "cast to NonNullable should succeed for dict with only unreferenced null values: {err}"
                )
            });
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}