1use vortex_array::ArrayRef;
5use vortex_array::dtype::DType;
6use vortex_array::dtype::Nullability;
7use vortex_array::scalar_fn::fns::cast::CastReduce;
8use vortex_error::VortexResult;
9
10use crate::ZstdArray;
11use crate::ZstdVTable;
12
13impl CastReduce for ZstdVTable {
14 fn cast(array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15 if !dtype.eq_ignore_nullability(array.dtype()) {
16 return Ok(None);
19 }
20
21 let src_nullability = array.dtype().nullability();
22 let target_nullability = dtype.nullability();
23
24 match (src_nullability, target_nullability) {
25 (Nullability::Nullable, Nullability::Nullable)
28 | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())),
29 (Nullability::NonNullable, Nullability::Nullable) => {
30 Ok(Some(
32 ZstdArray::new(
33 array.dictionary.clone(),
34 array.frames.clone(),
35 dtype.clone(),
36 array.metadata.clone(),
37 array.unsliced_n_rows(),
38 array.unsliced_validity.clone(),
39 )
40 .slice(array.slice_start()..array.slice_stop())?,
41 ))
42 }
43 (Nullability::Nullable, Nullability::NonNullable) => {
44 let sliced_len = array.slice_stop() - array.slice_start();
46 let has_nulls = !array
47 .unsliced_validity
48 .slice(array.slice_start()..array.slice_stop())?
49 .all_valid(sliced_len)?;
50
51 if has_nulls {
53 return Ok(None);
54 }
55
56 Ok(Some(
58 ZstdArray::new(
59 array.dictionary.clone(),
60 array.frames.clone(),
61 dtype.clone(),
62 array.metadata.clone(),
63 array.unsliced_n_rows(),
64 array.unsliced_validity.clone(),
65 )
66 .slice(array.slice_start()..array.slice_stop())?,
67 ))
68 }
69 }
70 }
71}
72
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;

    use crate::ZstdArray;

    /// A primitive-type change (i32 -> i64) is declined by the Zstd cast kernel, but the
    /// overall cast must still succeed and yield the widened values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Non-nullable -> nullable relaxes the dtype and must leave every value intact.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // Like the other cast tests, verify the data and not only the dtype: the re-tagged
        // array must still decode to the original values.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)])
        );
    }

    /// A sliced, nullable-but-all-valid array can be cast to non-nullable; the slice window
    /// must be preserved.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// A null that lies outside the slice window (index 0) must not prevent casting the
    /// window to non-nullable.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite against Zstd-compressed inputs of several
    /// primitive types and sizes.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0..1000).map(|i| i as u32).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(zstd.as_ref());
    }
}