Skip to main content

vortex_array/arrays/dict/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Dict {
18    fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        // Can have un-reference null values making the cast of values fail without a possible mask.
20        // TODO(joe): optimize this, could look at accessible values and fill_null not those?
21        if !dtype.is_nullable() && !array.values().validity()?.no_nulls() {
22            return Ok(None);
23        }
24        // Cast the dictionary values to the target type
25        let casted_values = array.values().cast(dtype.clone())?;
26
27        // If the codes are nullable but we are casting to non nullable dtype we have to remove nullability from codes as well
28        let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
29            array
30                .codes()
31                .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
32        } else {
33            array.codes().clone()
34        };
35
36        // SAFETY: casting does not alter invariants of the codes
37        Ok(Some(
38            unsafe {
39                DictArray::new_unchecked(casted_codes, casted_values)
40                    .set_all_values_referenced(array.has_all_values_referenced())
41            }
42            .into_array(),
43        ))
44    }
45}
46
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::Dict;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::session::ArraySession;

    /// Lazily-initialised session shared by the tests in this module.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Dictionary-encodes `array` with a fresh execution context and returns it as an array.
    fn dict_encoded(array: crate::ArrayRef) -> crate::ArrayRef {
        dict_encode(&array, &mut SESSION.create_execution_ctx())
            .unwrap()
            .into_array()
    }

    /// Casting an i32 dictionary to i64 yields the target dtype and the same logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let mut ctx = SESSION.create_execution_ctx();
        let source = buffer![1i32, 2, 3, 2, 1].into_array();

        let widened = dict_encode(&source, &mut ctx)
            .unwrap()
            .into_array()
            .cast(target.clone())
            .unwrap();
        assert_eq!(widened.dtype(), &target);

        #[expect(deprecated)]
        let decoded = widened.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dictionary built from nullable input to a nullable wider type keeps the
    /// result nullable.
    #[test]
    fn test_cast_dict_nullable() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);
        let mut ctx = SESSION.create_execution_ctx();
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None])
                .into_array();

        let widened = dict_encode(&source, &mut ctx)
            .unwrap()
            .into_array()
            .cast(target.clone())
            .unwrap();
        assert_eq!(widened.dtype(), &target);
    }

    /// Round-trips an all-valid dictionary through NonNullable -> Nullable -> NonNullable and
    /// checks the nullability of codes and values at each step.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let i32_non_nullable = DType::Primitive(PType::I32, Nullability::NonNullable);
        let i32_nullable = DType::Primitive(PType::I32, Nullability::Nullable);

        // Build a dict array from input without nulls.
        let mut ctx = SESSION.create_execution_ctx();
        let source = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&source, &mut ctx).unwrap();

        // Initial state: neither the codes nor the values are nullable.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to NonNullable (should be identity since already NonNullable).
        let non_nullable = dict
            .clone()
            .into_array()
            .cast(i32_non_nullable.clone())
            .unwrap();
        assert_eq!(non_nullable.dtype(), &i32_non_nullable);

        // Codes and values are both still NonNullable.
        let non_nullable_dict = non_nullable.as_::<Dict>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Cast to Nullable.
        let nullable = non_nullable.cast(i32_nullable.clone()).unwrap();
        assert_eq!(nullable.dtype(), &i32_nullable);

        // Only the values become Nullable; the codes keep their NonNullable dtype.
        let nullable_dict = nullable.as_::<Dict>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Cast back to NonNullable.
        let back_to_non_nullable = nullable.cast(i32_non_nullable.clone()).unwrap();
        assert_eq!(back_to_non_nullable.dtype(), &i32_non_nullable);

        // The logical values are unchanged by the round trip.
        #[expect(deprecated)]
        let original_values = dict.as_array().to_primitive();
        #[expect(deprecated)]
        let final_values = back_to_non_nullable.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encoded(buffer![1i32, 2, 3, 2, 1, 3].into_array()))]
    #[case(dict_encoded(buffer![100u32, 200, 100, 300, 200].into_array()))]
    #[case(dict_encoded(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()
    ))]
    #[case(dict_encoded(buffer![1.5f32, 2.5, 1.5, 3.5].into_array()))]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A dictionary whose only null value is never referenced by a code can still be cast to
    /// a non-nullable dtype.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        use crate::arrays::DictArray;
        use crate::validity::Validity;

        let f64_nullable = DType::Primitive(PType::F64, Nullability::Nullable);
        let f64_non_nullable = DType::Primitive(PType::F64, Nullability::NonNullable);

        // Values: [1.0, null, 3.0] -- index 1 is null, but no code points to it.
        let value_validity = Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true]));
        let values =
            PrimitiveArray::new(buffer![1.0f64, 0.0f64, 3.0f64], value_validity).into_array();
        // Codes: [0, 2, 0] -- only indices 0 and 2 are referenced, never 1.
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // The dict is Nullable (because values are nullable), but all codes point to valid values.
        assert_eq!(dict.dtype(), &f64_nullable);

        // Casting to NonNullable should succeed since all logical values are non-null.
        let result = dict.into_array().cast(f64_non_nullable.clone());
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );
        let casted = result.unwrap();
        assert_eq!(casted.dtype(), &f64_non_nullable);

        #[expect(deprecated)]
        let casted_prim = casted.to_primitive();
        assert_arrays_eq!(casted_prim, PrimitiveArray::from_iter([1.0f64, 3.0, 1.0]));
    }
}