1use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{ZstdArray, ZstdVTable};
10
11impl CastKernel for ZstdVTable {
12 fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13 if !dtype.is_nullable() || !array.all_valid() {
14 return Ok(None);
17 }
18 if array.dtype().eq_ignore_nullability(dtype) {
23 let new_validity = array
25 .unsliced_validity
26 .clone()
27 .cast_nullability(dtype.nullability(), array.len())?;
28
29 return Ok(Some(
30 ZstdArray::new(
31 array.dictionary.clone(),
32 array.frames.clone(),
33 dtype.clone(),
34 array.metadata.clone(),
35 array.unsliced_n_rows(),
36 new_validity,
37 )
38 ._slice(array.slice_start(), array.slice_stop())
39 .into_array(),
40 ));
41 }
42
43 Ok(None)
45 }
46}
47
48register_kernel!(CastKernelAdapter(ZstdVTable).lift());
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_array::{ToCanonical, assert_arrays_eq};
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::ZstdArray;

    /// Widening the primitive type is rejected by the Zstd kernel (the types differ beyond
    /// nullability). The generic cast path must still produce the right values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = cast(
            zstd.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// NonNullable -> Nullable on an unsliced array. Both the dtype and the decoded values of the
    /// rebuilt array must be correct.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = cast(
            zstd.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)])
        );
    }

    /// NonNullable -> Nullable on a sliced array. The kernel rebuilds the unsliced array and must
    /// re-apply the original slice bounds.
    #[test]
    fn test_cast_sliced_zstd_to_nullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5);

        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
        );
    }

    /// Nullable (all valid) -> NonNullable on a sliced array goes through the fallback path.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5);

        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The source has a null outside the slice window, so the sliced view is all valid and can be
    /// cast to non-nullable.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5);

        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite over several Zstd-compressed inputs.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0..1000).map(|i| i as u32).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(zstd.as_ref());
    }
}