vortex_alp/alp/compute/
cast.rs1use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::patches::Patches;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::alp::ALPArray;
13use crate::alp::ALPVTable;
14
15impl CastReduce for ALPVTable {
16 fn cast(array: &ALPArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 if array.dtype().eq_ignore_nullability(dtype) {
19 let new_encoded = array.encoded().cast(
22 array
23 .encoded()
24 .dtype()
25 .with_nullability(dtype.nullability()),
26 )?;
27
28 let new_patches = array
29 .patches()
30 .map(|p| {
31 if p.values().dtype() == dtype {
32 Ok(p.clone())
33 } else {
34 Patches::new(
35 p.array_len(),
36 p.offset(),
37 p.indices().clone(),
38 p.values().cast(dtype.clone())?,
39 p.chunk_offsets().clone(),
40 )
41 }
42 })
43 .transpose()?;
44
45 unsafe {
47 Ok(Some(
48 ALPArray::new_unchecked(
49 new_encoded,
50 array.exponents(),
51 new_patches,
52 dtype.clone(),
53 )
54 .into_array(),
55 ))
56 }
57 } else {
58 Ok(None)
59 }
60 }
61}
62
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::alp_encode;

    /// Regression test for issue 5766: casting a patched ALP array to a nullable dtype must
    /// decode to the same values as casting the unencoded input.
    #[test]
    fn issue_5766_test_cast_alp_with_patches_to_nullable() -> VortexResult<()> {
        let values = buffer![1.234f32, f32::NAN, 2.345, f32::INFINITY, 3.456].into_array();
        let alp = alp_encode(&values.to_primitive(), None)?;

        assert!(
            alp.patches().is_some(),
            "Test requires ALP array with patches"
        );

        let nullable_dtype = DType::Primitive(PType::F32, Nullability::Nullable);
        let casted = alp.to_array().cast(nullable_dtype.clone())?;

        let expected = values.cast(nullable_dtype)?;

        assert_arrays_eq!(casted.to_canonical()?.into_primitive(), expected);

        Ok(())
    }

    /// Widening an ALP-encoded f32 array to f64 yields the f64 dtype and every original value.
    #[test]
    fn test_cast_alp_f32_to_f64() -> VortexResult<()> {
        let values = buffer![1.5f32, 2.5, 3.5, 4.5].into_array();
        let alp = alp_encode(&values.to_primitive(), None)?;

        let casted = alp
            .to_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        let decoded = casted.to_canonical()?.into_primitive();
        let actual = decoded.as_slice::<f64>();
        let expected = [1.5f64, 2.5, 3.5, 4.5];
        assert_eq!(actual.len(), expected.len());
        // Check every element, not just a prefix.
        for (got, want) in actual.iter().zip(expected) {
            assert!(
                (got - want).abs() < f64::EPSILON,
                "expected {want}, got {got}"
            );
        }

        Ok(())
    }

    /// Casting ALP-encoded integral floats to i32 yields the matching integers.
    #[test]
    fn test_cast_alp_to_int() -> VortexResult<()> {
        let values = buffer![1.0f32, 2.0, 3.0, 4.0].into_array();
        let alp = alp_encode(&values.to_primitive(), None)?;

        let casted = alp
            .to_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let decoded = casted.to_canonical()?.into_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i32, 2, 3, 4]));

        Ok(())
    }

    /// Runs the shared cast conformance suite against ALP-encoded versions of several inputs.
    #[rstest]
    #[case(buffer![1.23f32, 4.56, 7.89, 10.11, 12.13].into_array())]
    #[case(buffer![100.1f64, 200.2, 300.3, 400.4, 500.5].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![42.42f64].into_array())]
    #[case(buffer![0.0f32, -1.5, 2.5, -3.5, 4.5].into_array())]
    fn test_cast_alp_conformance(#[case] array: vortex_array::ArrayRef) -> VortexResult<()> {
        // Propagate encoding failures through the test's Result instead of panicking.
        let alp = alp_encode(&array.to_primitive(), None)?;
        test_cast_conformance(&alp.to_array());

        Ok(())
    }
}