1use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{PcoArray, PcoVTable};
10
11impl CastKernel for PcoVTable {
12 fn cast(&self, array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13 if !dtype.is_nullable() || !array.all_valid() {
14 return Ok(None);
18 }
19 if array.dtype().eq_ignore_nullability(dtype) {
24 let new_validity = array
26 .unsliced_validity
27 .clone()
28 .cast_nullability(dtype.nullability(), array.len())?;
29
30 return Ok(Some(
31 PcoArray::new(
32 array.chunk_metas.clone(),
33 array.pages.clone(),
34 dtype.clone(),
35 array.metadata.clone(),
36 array.unsliced_n_rows(),
37 new_validity,
38 )
39 ._slice(array.slice_start(), array.slice_stop())
40 .into_array(),
41 ));
42 }
43
44 Ok(None)
46 }
47}
48
// Register `PcoVTable`'s `CastKernel` implementation, wrapped in `CastKernelAdapter`, via
// `register_kernel!`.
register_kernel!(CastKernelAdapter(PcoVTable).lift());
50
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_array::{ToCanonical, assert_arrays_eq};
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::PcoArray;

    /// A primitive-type change (f32 -> f64) is rejected by the Pco kernel, so the result comes
    /// from the fallback path; every value must survive the widening.
    #[test]
    fn test_cast_pco_f32_to_f64() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![1.0f32, 2.0, 3.0, 4.0, 5.0]),
            Validity::NonNullable,
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();

        let casted = cast(
            pco.as_ref(),
            &DType::Primitive(PType::F64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let f64_values = decoded.as_slice::<f64>();
        let expected = [1.0f64, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(f64_values.len(), expected.len());
        for (actual, want) in f64_values.iter().zip(expected) {
            assert!((actual - want).abs() < f64::EPSILON);
        }
    }

    /// Widening nullability with the same primitive type goes through the Pco kernel itself
    /// and must keep every value intact.
    #[test]
    fn test_cast_pco_nullability_change() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();

        let casted = cast(
            pco.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[10, 20, 30, 40]);
    }

    /// Widening nullability on a *sliced* array exercises the kernel's re-slicing of the
    /// rebuilt array: the result must expose exactly the sliced rows.
    #[test]
    fn test_cast_sliced_pco_nonnullable_to_nullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::NonNullable,
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = pco.slice(1..5);

        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );
        assert_eq!(casted.len(), 4);

        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[20, 30, 40, 50]);
    }

    /// Narrowing a sliced, all-valid nullable array to non-nullable falls back and must decode
    /// only the sliced rows.
    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = pco.slice(1..5);
        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let u32_values = decoded.as_slice::<u32>();
        assert_eq!(u32_values, &[20, 30, 40, 50]);
    }

    /// The unsliced data contains a null outside the slice; narrowing the all-valid slice to
    /// non-nullable must still succeed and yield the sliced values.
    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = pco.slice(1..5);
        let casted = cast(
            sliced.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    #[rstest]
    #[case::f32(PrimitiveArray::new(
        Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![100.1f64, 200.2, 300.3, 400.4, 500.5]),
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        Buffer::copy_from(vec![1000u64, 2000, 3000, 4000]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42.42f64]),
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        test_cast_conformance(pco.as_ref());
    }
}