1use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::dtype::DType;
7use vortex_array::dtype::Nullability;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_array::validity::Validity;
10use vortex_error::VortexResult;
11
12use crate::Zstd;
13use crate::ZstdArray;
14
15impl CastReduce for Zstd {
16 fn cast(array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 if !dtype.eq_ignore_nullability(array.dtype()) {
18 return Ok(None);
21 }
22
23 let src_nullability = array.dtype().nullability();
24 let target_nullability = dtype.nullability();
25
26 match (src_nullability, target_nullability) {
27 (Nullability::Nullable, Nullability::Nullable)
30 | (Nullability::NonNullable, Nullability::NonNullable) => {
31 Ok(Some(array.clone().into_array()))
32 }
33 (Nullability::NonNullable, Nullability::Nullable) => {
34 Ok(Some(
36 ZstdArray::new(
37 array.dictionary.clone(),
38 array.frames.clone(),
39 dtype.clone(),
40 array.metadata.clone(),
41 array.unsliced_n_rows(),
42 array.unsliced_validity.clone(),
43 )
44 .slice(array.slice_start()..array.slice_stop())?,
45 ))
46 }
47 (Nullability::Nullable, Nullability::NonNullable) => {
48 let has_nulls = !matches!(
50 array.validity()?,
51 Validity::AllValid | Validity::NonNullable
52 );
53
54 if has_nulls {
56 return Ok(None);
57 }
58
59 Ok(Some(
61 ZstdArray::new(
62 array.dictionary.clone(),
63 array.frames.clone(),
64 dtype.clone(),
65 array.metadata.clone(),
66 array.unsliced_n_rows(),
67 array.unsliced_validity.clone(),
68 )
69 .slice(array.slice_start()..array.slice_stop())?,
70 ))
71 }
72 }
73 }
74}
75
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::ZstdArray;

    /// Casting a Zstd-encoded `i32` array to `i64` yields the same values as `i64`.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let encoded = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let widened = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(widened.dtype(), &target);

        let expected = PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]);
        assert_arrays_eq!(widened.to_primitive(), expected);
    }

    /// Casting a non-nullable Zstd-encoded array to the nullable variant of its dtype
    /// produces an array that reports the nullable dtype.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let encoded = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let nullable = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(nullable.dtype(), &target);
    }

    /// A slice of a nullable array whose validity flags are all `true` casts to
    /// non-nullable and decodes to exactly the sliced values.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let encoded = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let non_nullable = window.cast(target.clone()).unwrap();
        assert_eq!(non_nullable.dtype(), &target);

        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(non_nullable.to_primitive(), expected);
    }

    /// The only null sits at index 0, outside the `1..5` slice, so the slice still casts
    /// to non-nullable and decodes to the sliced values.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let encoded = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let non_nullable = window.cast(target.clone()).unwrap();
        assert_eq!(non_nullable.dtype(), &target);

        assert_arrays_eq!(
            non_nullable.to_primitive(),
            PrimitiveArray::from_iter([20u32, 30, 40, 50])
        );
    }

    /// Runs the shared cast conformance checks against Zstd-encoded primitive inputs.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let encoded = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(&encoded.into_array());
    }
}