vortex_array/arrays/constant/compute/
mod.rs1mod binary_numeric;
2mod boolean;
3mod cast;
4mod compare;
5mod invert;
6mod search_sorted;
7mod take;
8
9use num_traits::{CheckedMul, ToPrimitive};
10use vortex_dtype::{NativePType, PType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_err};
12use vortex_mask::Mask;
13use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
14
15use crate::arrays::ConstantEncoding;
16use crate::arrays::constant::ConstantArray;
17use crate::compute::{
18 BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FilterKernel, FilterKernelAdapter,
19 InvertFn, KernelRef, ScalarAtFn, SearchSortedFn, SliceFn, SumFn, TakeFn, UncompressedSizeFn,
20};
21use crate::stats::Stat;
22use crate::vtable::ComputeVTable;
23use crate::{Array, ArrayComputeImpl, ArrayRef};
24
25impl ArrayComputeImpl for ConstantArray {
26 const FILTER: Option<KernelRef> = FilterKernelAdapter(ConstantEncoding).some();
27}
28
29impl ComputeVTable for ConstantEncoding {
30 fn binary_boolean_fn(&self) -> Option<&dyn BinaryBooleanFn<&dyn Array>> {
31 Some(self)
32 }
33
34 fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
35 Some(self)
36 }
37
38 fn cast_fn(&self) -> Option<&dyn CastFn<&dyn Array>> {
39 Some(self)
40 }
41
42 fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
43 Some(self)
44 }
45
46 fn invert_fn(&self) -> Option<&dyn InvertFn<&dyn Array>> {
47 Some(self)
48 }
49
50 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
51 Some(self)
52 }
53
54 fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
55 Some(self)
56 }
57
58 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
59 Some(self)
60 }
61
62 fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
63 Some(self)
64 }
65
66 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
67 Some(self)
68 }
69
70 fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
71 Some(self)
72 }
73}
74
75impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
76 fn scalar_at(&self, array: &ConstantArray, _index: usize) -> VortexResult<Scalar> {
77 Ok(array.scalar().clone())
78 }
79}
80
81impl SliceFn<&ConstantArray> for ConstantEncoding {
82 fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
83 Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
84 }
85}
86
87impl FilterKernel for ConstantEncoding {
88 fn filter(&self, array: &ConstantArray, mask: &Mask) -> VortexResult<ArrayRef> {
89 Ok(ConstantArray::new(array.scalar().clone(), mask.true_count()).into_array())
90 }
91}
92
93impl UncompressedSizeFn<&ConstantArray> for ConstantEncoding {
94 fn uncompressed_size(&self, array: &ConstantArray) -> VortexResult<usize> {
95 let scalar = array.scalar();
96
97 let size = match scalar.as_bool_opt() {
98 Some(_) => array.len() / 8,
99 None => array.scalar().nbytes() * array.len(),
100 };
101 Ok(size)
102 }
103}
104
105impl SumFn<&ConstantArray> for ConstantEncoding {
106 fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
107 let sum_dtype = Stat::Sum
108 .dtype(array.dtype())
109 .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
110 let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
111
112 let scalar = array.scalar();
113
114 let scalar_value = match_each_native_ptype!(
115 sum_ptype,
116 unsigned: |$T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() }
117 signed: |$T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() }
118 floating: |$T| { sum_float(scalar.as_primitive(), array.len())?.into() }
119 );
120
121 Ok(Scalar::new(sum_dtype, scalar_value))
122 }
123}
124
125fn sum_integral<T>(
126 primitive_scalar: PrimitiveScalar<'_>,
127 array_len: usize,
128) -> VortexResult<Option<T>>
129where
130 T: FromPrimitiveOrF16 + NativePType + CheckedMul,
131 Scalar: From<Option<T>>,
132{
133 let v = primitive_scalar.as_::<T>()?;
134 let array_len =
135 T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
136 let sum = v.and_then(|v| v.checked_mul(&array_len));
137
138 Ok(sum)
139}
140
141fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
142 let v = primitive_scalar.as_::<f64>()?;
143 let array_len = array_len
144 .to_f64()
145 .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
146
147 Ok(v.map(|v| v * array_len))
148}
149
#[cfg(test)]
mod test {
    use vortex_dtype::half::f16;
    use vortex_scalar::Scalar;

    use super::ConstantArray;
    use crate::array::Array;
    use crate::compute::conformance::mask::test_mask;

    /// Runs the shared mask conformance check over 5-element constant arrays.
    ///
    /// The cases cover a null `i32`, an unsigned integer, a non-finite (infinite) `f32`,
    /// and an `f16`.
    #[test]
    fn test_mask_constant() {
        let cases = [
            Scalar::null_typed::<i32>(),
            Scalar::from(3u16),
            Scalar::from(1.0f32 / 0.0f32),
            Scalar::from(f16::from_f32(3.0f32)),
        ];

        for scalar in cases {
            test_mask(&ConstantArray::new(scalar, 5).into_array());
        }
    }
}
166}