1use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::register_kernel;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11
12use crate::PcoArray;
13use crate::PcoVTable;
14
15impl CastKernel for PcoVTable {
16 fn cast(&self, array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 if !dtype.is_nullable() || !array.all_valid() {
18 return Ok(None);
22 }
23 if array.dtype().eq_ignore_nullability(dtype) {
28 let new_validity = array
30 .unsliced_validity
31 .clone()
32 .cast_nullability(dtype.nullability(), array.len())?;
33
34 return Ok(Some(
35 PcoArray::new(
36 array.chunk_metas.clone(),
37 array.pages.clone(),
38 dtype.clone(),
39 array.metadata.clone(),
40 array.unsliced_n_rows(),
41 new_validity,
42 )
43 ._slice(array.slice_start(), array.slice_stop())
44 .into_array(),
45 ));
46 }
47
48 Ok(None)
50 }
51}
52
// Registers `PcoVTable`'s `CastKernel` implementation, wrapped in `CastKernelAdapter`, with the
// compute kernel registry so `cast` can try it on PCO arrays. When `PcoVTable::cast` returns
// `Ok(None)`, the generic cast path handles the request instead.
register_kernel!(CastKernelAdapter(PcoVTable).lift());
54
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::PcoArray;

    /// A primitive-type change (`f32 -> f64`) must preserve every value, not just the first.
    #[test]
    fn test_cast_pco_f32_to_f64() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![1.0f32, 2.0, 3.0, 4.0, 5.0]),
            Validity::NonNullable,
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();

        let casted = cast(
            pco.as_ref(),
            &DType::Primitive(PType::F64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let f64_values = decoded.as_slice::<f64>();
        let expected = [1.0f64, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(f64_values.len(), expected.len());
        for (actual, wanted) in f64_values.iter().zip(expected) {
            assert!(
                (actual - wanted).abs() < f64::EPSILON,
                "expected {wanted}, got {actual}"
            );
        }
    }

    /// Widening nullability keeps the dtype's primitive type and all decoded values.
    #[test]
    fn test_cast_pco_nullability_change() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();

        let casted = cast(
            pco.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[10, 20, 30, 40]);
    }

    /// Widening nullability on a sliced array must honour the slice window.
    #[test]
    fn test_cast_sliced_pco_nonnullable_to_nullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::NonNullable,
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = pco.slice(1..5);

        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[20, 30, 40, 50]);
    }

    /// Narrowing an all-valid, sliced nullable array to non-nullable keeps the sliced values.
    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = pco.slice(1..5);
        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let u32_values = decoded.as_slice::<u32>();
        assert_eq!(u32_values, &[20, 30, 40, 50]);
    }

    /// A slice that excludes the only null row can be cast to non-nullable.
    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = pco.slice(1..5);
        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite over PCO arrays of several primitive types.
    #[rstest]
    #[case::f32(PrimitiveArray::new(
        Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![100.1f64, 200.2, 300.3, 400.4, 500.5]),
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        Buffer::copy_from(vec![1000u64, 2000, 3000, 4000]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42.42f64]),
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        test_cast_conformance(pco.as_ref());
    }
}