1use vortex_array::ArrayRef;
5use vortex_array::compute::CastKernel;
6use vortex_array::compute::CastKernelAdapter;
7use vortex_array::register_kernel;
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_error::VortexResult;
11
12use crate::ZstdArray;
13use crate::ZstdVTable;
14
15impl CastKernel for ZstdVTable {
16 fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 if !dtype.eq_ignore_nullability(array.dtype()) {
18 return Ok(None);
21 }
22
23 let src_nullability = array.dtype().nullability();
24 let target_nullability = dtype.nullability();
25
26 match (src_nullability, target_nullability) {
27 (Nullability::Nullable, Nullability::Nullable)
30 | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())),
31 (Nullability::NonNullable, Nullability::Nullable) => Ok(Some(
32 ZstdArray::new(
34 array.dictionary.clone(),
35 array.frames.clone(),
36 dtype.clone(),
37 array.metadata.clone(),
38 array.unsliced_n_rows(),
39 array.unsliced_validity.clone(),
40 )
41 .slice(array.slice_start()..array.slice_stop()),
42 )),
43 (Nullability::Nullable, Nullability::NonNullable) => {
44 let sliced_len = array.slice_stop() - array.slice_start();
46 let has_nulls = !array
47 .unsliced_validity
48 .slice(array.slice_start()..array.slice_stop())
49 .all_valid(sliced_len);
50
51 if has_nulls {
53 return Ok(None);
54 }
55
56 Ok(Some(
58 ZstdArray::new(
59 array.dictionary.clone(),
60 array.frames.clone(),
61 dtype.clone(),
62 array.metadata.clone(),
63 array.unsliced_n_rows(),
64 array.unsliced_validity.clone(),
65 )
66 .slice(array.slice_start()..array.slice_stop()),
67 ))
68 }
69 }
70 }
71}
72
// Register `ZstdVTable`'s `CastKernel` implementation (wrapped in `CastKernelAdapter`) via
// `register_kernel!`, so the generic cast machinery can pick it up for ZSTD arrays.
register_kernel!(CastKernelAdapter(ZstdVTable).lift());
74
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::ZstdArray;

    /// Widening i32 -> i64 changes more than nullability; checks the resulting dtype and values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let result = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let decoded = result.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Relaxing non-nullable u32 to nullable u32; only the resulting dtype is checked.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();
        let target = DType::Primitive(PType::U32, Nullability::Nullable);

        let result = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);
    }

    /// Nullable -> non-nullable on a slice of an array whose validity is entirely valid.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true; 6]),
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = compressed.slice(1..5);
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);

        let result = cast(window.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let decoded = result.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[20, 30, 40, 50]);
    }

    /// Nullable -> non-nullable where the only null (row 0) lies outside the sliced window.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let compressed = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = compressed.slice(1..5);
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);

        let result = cast(window.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let decoded = result.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite over several ZSTD-compressed primitive arrays.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0..1000u32).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let compressed = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(compressed.as_ref());
    }
}