1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::dtype::DType;
8use vortex_array::dtype::Nullability;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_array::validity::Validity;
11use vortex_array::vtable::child_to_validity;
12use vortex_error::VortexResult;
13
14use crate::Zstd;
15use crate::ZstdData;
16impl CastReduce for Zstd {
17 fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18 if !dtype.eq_ignore_nullability(array.dtype()) {
19 return Ok(None);
22 }
23
24 let src_nullability = array.dtype().nullability();
25 let target_nullability = dtype.nullability();
26
27 match (src_nullability, target_nullability) {
28 (Nullability::Nullable, Nullability::Nullable)
31 | (Nullability::NonNullable, Nullability::NonNullable) => {
32 Ok(Some(array.array().clone()))
33 }
34 (Nullability::NonNullable, Nullability::Nullable) => {
35 let unsliced_validity =
37 child_to_validity(&array.slots()[0], array.dtype().nullability());
38 Ok(Some(
39 Zstd::try_new(
40 dtype.clone(),
41 ZstdData::new(
42 array.dictionary.clone(),
43 array.frames.clone(),
44 array.metadata.clone(),
45 array.unsliced_n_rows(),
46 ),
47 unsliced_validity,
48 )?
49 .into_array()
50 .slice(array.slice_start()..array.slice_stop())?,
51 ))
52 }
53 (Nullability::Nullable, Nullability::NonNullable) => {
54 let unsliced_validity =
56 child_to_validity(&array.slots()[0], array.dtype().nullability());
57 let has_nulls = !matches!(
58 unsliced_validity.slice(array.slice_start()..array.slice_stop())?,
59 Validity::AllValid | Validity::NonNullable
60 );
61
62 if has_nulls {
64 return Ok(None);
65 }
66
67 Ok(Some(
69 Zstd::try_new(
70 dtype.clone(),
71 ZstdData::new(
72 array.dictionary.clone(),
73 array.frames.clone(),
74 array.metadata.clone(),
75 array.unsliced_n_rows(),
76 ),
77 unsliced_validity,
78 )?
79 .into_array()
80 .slice(array.slice_start()..array.slice_stop())?,
81 ))
82 }
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Zstd;

    /// A type change (i32 -> i64) is not handled by the Zstd kernel itself, but the cast must
    /// still produce the widened values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let zstd = Zstd::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Non-nullable -> nullable keeps every value and marks all of them valid.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let zstd = Zstd::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)])
        );
    }

    /// Non-nullable -> nullable on a sliced array must re-apply the slice to the rebuilt array.
    #[test]
    fn test_cast_sliced_zstd_nonnullable_to_nullable() {
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40, 50, 60]);
        let zstd = Zstd::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
        );
    }

    /// Nullable -> non-nullable where the validity is `AllValid` (no validity array at all).
    #[test]
    fn test_cast_sliced_zstd_all_valid_to_nonnullable() {
        let values = PrimitiveArray::new(buffer![10u32, 20, 30, 40, 50, 60], Validity::AllValid);
        let zstd = Zstd::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// Nullable -> non-nullable with an explicit all-true validity on a sliced array.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = Zstd::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        // Only the sliced rows 1..5 must come back.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The only null sits outside the slice, so the visible rows can be cast to non-nullable.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = Zstd::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite over a few non-nullable Zstd arrays.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = Zstd::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(&zstd.into_array());
    }
}