vortex_zstd/compute/
cast.rs1use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{ZstdArray, ZstdVTable};
10
11impl CastKernel for ZstdVTable {
12 fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13 if array.dtype().eq_ignore_nullability(dtype) {
18 let new_validity = array
20 .unsliced_validity
21 .clone()
22 .cast_nullability(dtype.nullability(), array.len())?;
23
24 return Ok(Some(
25 ZstdArray::new(
26 array.dictionary.clone(),
27 array.frames.clone(),
28 dtype.clone(),
29 array.metadata.clone(),
30 array.unsliced_n_rows(),
31 new_validity,
32 )
33 ._slice(array.slice_start(), array.slice_stop())
34 .into_array(),
35 ));
36 }
37
38 Ok(None)
40 }
41}
42
43register_kernel!(CastKernelAdapter(ZstdVTable).lift());
44
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::ZstdArray;

    /// Casting a Zstd-encoded `i32` array to `i64` yields the requested dtype and the same
    /// values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let widened = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(widened.dtype(), &target);

        let decoded = widened.to_primitive();
        assert_eq!(decoded.as_slice::<i64>(), &[1i64, 2, 3, 4, 5]);
    }

    /// Casting a non-nullable Zstd-encoded `u32` array to nullable `u32` reports the nullable
    /// dtype.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let nullable = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(nullable.dtype(), &target);
    }

    /// Runs the shared cast conformance checks over several Zstd-encoded primitive arrays.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0..1000u32).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] primitive: PrimitiveArray) {
        let compressed = ZstdArray::from_primitive(&primitive, 0, 0).unwrap();
        test_cast_conformance(compressed.as_ref());
    }
}