1use vortex_array::ArrayRef;
5use vortex_array::dtype::DType;
6use vortex_array::dtype::Nullability;
7use vortex_array::scalar_fn::fns::cast::CastReduce;
8use vortex_error::VortexResult;
9
10use crate::ZstdArray;
11use crate::ZstdVTable;
12
13impl CastReduce for ZstdVTable {
14 fn cast(array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15 if !dtype.eq_ignore_nullability(array.dtype()) {
16 return Ok(None);
19 }
20
21 let src_nullability = array.dtype().nullability();
22 let target_nullability = dtype.nullability();
23
24 match (src_nullability, target_nullability) {
25 (Nullability::Nullable, Nullability::Nullable)
28 | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())),
29 (Nullability::NonNullable, Nullability::Nullable) => {
30 Ok(Some(
32 ZstdArray::new(
33 array.dictionary.clone(),
34 array.frames.clone(),
35 dtype.clone(),
36 array.metadata.clone(),
37 array.unsliced_n_rows(),
38 array.unsliced_validity.clone(),
39 )
40 .slice(array.slice_start()..array.slice_stop())?,
41 ))
42 }
43 (Nullability::Nullable, Nullability::NonNullable) => {
44 let sliced_len = array.slice_stop() - array.slice_start();
46 let has_nulls = !array
47 .unsliced_validity
48 .slice(array.slice_start()..array.slice_stop())?
49 .all_valid(sliced_len)?;
50
51 if has_nulls {
53 return Ok(None);
54 }
55
56 Ok(Some(
58 ZstdArray::new(
59 array.dictionary.clone(),
60 array.frames.clone(),
61 dtype.clone(),
62 array.metadata.clone(),
63 array.unsliced_n_rows(),
64 array.unsliced_validity.clone(),
65 )
66 .slice(array.slice_start()..array.slice_stop())?,
67 ))
68 }
69 }
70 }
71}
72
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::ZstdArray;

    /// Widening i32 -> i64 changes the logical type, so `ZstdVTable::cast` declines it.
    /// The cast must still succeed overall and decode to the widened values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Non-nullable -> nullable keeps the element type: the dtype changes and the values survive.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // Check the values too, not just the dtype, matching the other tests in this module.
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Non-nullable -> nullable on a sliced array must honour the slice window.
    #[test]
    fn test_cast_sliced_zstd_nonnullable_to_nullable() {
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40, 50, 60]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Nullable (but all-valid) -> non-nullable on a sliced array.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// Nullable -> non-nullable where the only null (index 0) lies outside the slice window.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite over several non-nullable primitive inputs.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    // `buffer![0u32..1000]` would denote a one-element buffer holding a `Range` (vec!-style
    // macro syntax), not 1000 integers. Build the 1000 values explicitly instead.
    #[case::large(PrimitiveArray::from_iter(0u32..1000))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(&zstd.to_array());
    }
}