1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::scalar::Scalar;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_error::VortexResult;
12
13use crate::Sparse;
14
15impl CastReduce for Sparse {
16 fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 let casted_patches = array
18 .patches()
19 .clone()
20 .map_values(|values| values.cast(dtype.clone()))?;
21
22 let casted_fill = if array.patches().num_patches() == array.len() {
23 Scalar::default_value(dtype)
26 } else {
27 array.fill_scalar().cast(dtype)?
28 };
29
30 Ok(Some(
31 Sparse::try_new_from_patches(casted_patches, casted_fill)?.into_array(),
32 ))
33 }
34}
35
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Sparse;
    use crate::SparseArray;

    /// Shared session used to create execution contexts for decoding results.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Widening a non-nullable sparse array keeps both the patched values
    /// and the fill value at unpatched positions.
    #[test]
    fn test_cast_sparse_i32_to_i64() {
        let mut ctx = SESSION.create_execution_ctx();
        let sparse = Sparse::try_new(
            buffer![2u64, 5, 8].into_array(),
            buffer![100i32, 200, 300].into_array(),
            10,
            Scalar::from(0i32),
        )
        .unwrap();

        let casted = sparse
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let expected = PrimitiveArray::from_iter([0i64, 0, 100, 0, 0, 200, 0, 0, 300, 0]);
        let casted_primitive = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(casted_primitive, expected);
    }

    /// Casting a null-filled sparse array to a nullable dtype keeps the
    /// unpatched positions null and widens the patched values.
    #[test]
    fn test_cast_sparse_with_null_fill() {
        let mut ctx = SESSION.create_execution_ctx();
        let sparse = Sparse::try_new(
            buffer![1u64, 3, 5].into_array(),
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(),
            8,
            Scalar::null_native::<i32>(),
        )
        .unwrap();

        let casted = sparse
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Checking the dtype alone does not catch a cast that drops the
        // null fill or misplaces patches, so verify the decoded values too.
        let expected = PrimitiveArray::from_option_iter([
            None,
            Some(42i64),
            None,
            Some(84),
            None,
            Some(126),
            None,
            None,
        ]);
        let casted_primitive = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(casted_primitive, expected);
    }

    /// Runs the shared cast conformance suite over several sparse layouts:
    /// integer and float values, a nullable fill with a null patch, and a
    /// single patch.
    #[rstest]
    #[case(Sparse::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    #[case(Sparse::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1.5f32, 2.5, 3.5].into_array(),
        10,
        Scalar::from(0.0f32)
    ).unwrap())]
    #[case(Sparse::try_new(
        buffer![1u64, 3, 7].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
        10,
        Scalar::null_native::<i32>()
    ).unwrap())]
    #[case(Sparse::try_new(
        buffer![5u64].into_array(),
        buffer![42u8].into_array(),
        10,
        Scalar::from(0u8)
    ).unwrap())]
    fn test_cast_sparse_conformance(#[case] array: SparseArray) {
        test_cast_conformance(&array.into_array());
    }

    /// When every position is patched, the null fill is never observed, so
    /// casting to a non-nullable dtype must succeed rather than fail on the
    /// fill.
    #[test]
    fn test_cast_sparse_null_fill_all_patched_to_non_nullable() -> vortex_error::VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let sparse = Sparse::try_new(
            buffer![0u64, 1, 2, 3, 4].into_array(),
            buffer![10u64, 20, 30, 40, 50].into_array(),
            5,
            Scalar::null_native::<u64>(),
        )?;

        let casted = sparse
            .into_array()
            .cast(DType::Primitive(PType::U64, Nullability::NonNullable))?;

        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );

        let expected = PrimitiveArray::from_iter([10u64, 20, 30, 40, 50]);
        let casted_primitive = casted.execute::<PrimitiveArray>(&mut ctx)?;
        assert_arrays_eq!(casted_primitive, expected);
        Ok(())
    }

    /// `fill_null` on a null-filled sparse array yields a non-nullable array
    /// in which the unpatched positions hold the replacement value.
    #[test]
    fn test_fill_null_sparse_with_null_fill() -> vortex_error::VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let sparse = Sparse::try_new(
            buffer![1u64, 3].into_array(),
            buffer![10u64, 20].into_array(),
            5,
            Scalar::null_native::<u64>(),
        )?;

        let filled = sparse.into_array().fill_null(Scalar::from(0u64))?;

        assert_eq!(
            filled.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );

        // Also check the data: the replacement must land only on the
        // unpatched positions.
        let expected = PrimitiveArray::from_iter([0u64, 10, 0, 20, 0]);
        let filled_primitive = filled.execute::<PrimitiveArray>(&mut ctx)?;
        assert_arrays_eq!(filled_primitive, expected);
        Ok(())
    }
}