use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::IntoArray;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::scalar_fn::fns::cast::CastReduce;
use vortex_array::validity::Validity;
use vortex_array::vtable::child_to_validity;
use vortex_error::VortexResult;
12
13use crate::Zstd;
14use crate::ZstdData;
15
16impl CastReduce for Zstd {
17 fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18 if !dtype.eq_ignore_nullability(array.dtype()) {
19 return Ok(None);
22 }
23
24 let src_nullability = array.dtype().nullability();
25 let target_nullability = dtype.nullability();
26
27 let new_validity = match (src_nullability, target_nullability) {
28 (Nullability::Nullable, Nullability::Nullable)
31 | (Nullability::NonNullable, Nullability::NonNullable) => {
32 return Ok(Some(array.array().clone()));
33 }
34 (Nullability::NonNullable, Nullability::Nullable) => {
35 child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability())
37 }
38 (Nullability::Nullable, Nullability::NonNullable) => {
39 let unsliced_validity =
41 child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
42 let has_nulls = !unsliced_validity
43 .slice(array.slice_start()..array.slice_stop())?
44 .no_nulls();
45
46 if has_nulls {
48 return Ok(None);
49 }
50 unsliced_validity
51 }
52 };
53
54 Ok(Some(
56 Zstd::try_new(
57 dtype.clone(),
58 ZstdData::new(
59 array.dictionary.clone(),
60 array.frames.clone(),
61 array.metadata.clone(),
62 array.unsliced_n_rows(),
63 ),
64 new_validity,
65 )?
66 .into_array()
67 .slice(array.slice_start()..array.slice_stop())?,
68 ))
69 }
70}
71
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Zstd;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let zstd = Zstd::from_primitive(&values, 0, 0, &mut ctx).unwrap();

        let casted = zstd
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    #[test]
    fn test_cast_zstd_nullability_change() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let zstd = Zstd::from_primitive(&values, 0, 0, &mut ctx).unwrap();

        let casted = zstd
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // Widening to nullable must keep every value and mark every row as valid.
        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)])
        );
    }

    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = Zstd::from_primitive(&values, 0, 128, &mut ctx).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let mut ctx = SESSION.create_execution_ctx();
        // The only null sits outside the slice taken below, so the cast is lossless.
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = Zstd::from_primitive(&values, 0, 128, &mut ctx).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    #[test]
    fn test_cast_sliced_zstd_with_nulls_to_nonnullable_fails() {
        let mut ctx = SESSION.create_execution_ctx();
        // The null at index 2 lies inside the slice taken below.
        let values = PrimitiveArray::from_option_iter([
            Some(10u32),
            Some(20),
            None,
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = Zstd::from_primitive(&values, 0, 128, &mut ctx).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        // The cast may be reported either eagerly or when executed; either way it must fail.
        let result = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .and_then(|casted| casted.execute::<PrimitiveArray>(&mut ctx));
        assert!(result.is_err());
    }

    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd =
            Zstd::from_primitive(&values, 0, 0, &mut SESSION.create_execution_ctx()).unwrap();
        test_cast_conformance(&zstd.into_array());
    }
}