vortex_array/arrays/constant/compute/
mod.rs

1mod binary_numeric;
2mod boolean;
3mod cast;
4mod compare;
5mod invert;
6mod search_sorted;
7mod take;
8
9use num_traits::{CheckedMul, ToPrimitive};
10use vortex_dtype::{NativePType, PType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_err};
12use vortex_mask::Mask;
13use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
14
15use crate::arrays::ConstantEncoding;
16use crate::arrays::constant::ConstantArray;
17use crate::compute::{
18    BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FilterKernel, FilterKernelAdapter,
19    InvertFn, KernelRef, ScalarAtFn, SearchSortedFn, SliceFn, SumFn, TakeFn, UncompressedSizeFn,
20};
21use crate::stats::Stat;
22use crate::vtable::ComputeVTable;
23use crate::{Array, ArrayComputeImpl, ArrayRef};
24
/// Kernel registration for [`ConstantArray`].
impl ArrayComputeImpl for ConstantArray {
    /// Filter kernel slot: the `ConstantEncoding` filter kernel (see the `FilterKernel`
    /// impl in this file) wrapped in a `FilterKernelAdapter`.
    // NOTE(review): `.some()` presumably wraps the adapter into `Some(KernelRef)` in a
    // const context — confirm against `FilterKernelAdapter`.
    const FILTER: Option<KernelRef> = FilterKernelAdapter(ConstantEncoding).some();
}
28
/// Compute function table for the constant encoding.
///
/// Every accessor returns `Some(self)`, i.e. `ConstantEncoding` itself is handed out as
/// the implementation of each listed compute function. The `scalar_at`, `slice`, `sum`
/// and `uncompressed_size` implementations are defined in this file; the remaining ones
/// presumably live in the sibling submodules declared at the top of this module
/// (`binary_numeric`, `boolean`, `cast`, `compare`, `invert`, `search_sorted`, `take`).
impl ComputeVTable for ConstantEncoding {
    /// Binary boolean operations.
    fn binary_boolean_fn(&self) -> Option<&dyn BinaryBooleanFn<&dyn Array>> {
        Some(self)
    }

    /// Binary numeric operations.
    fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
        Some(self)
    }

    /// Casting.
    fn cast_fn(&self) -> Option<&dyn CastFn<&dyn Array>> {
        Some(self)
    }

    /// Comparison.
    fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
        Some(self)
    }

    /// Inversion.
    fn invert_fn(&self) -> Option<&dyn InvertFn<&dyn Array>> {
        Some(self)
    }

    /// Single-element lookup.
    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
        Some(self)
    }

    /// Sorted search.
    fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
        Some(self)
    }

    /// Slicing.
    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
        Some(self)
    }

    /// Summation.
    fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
        Some(self)
    }

    /// Take (gather by indices).
    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
        Some(self)
    }

    /// Uncompressed size estimation.
    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
        Some(self)
    }
}
74
75impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
76    fn scalar_at(&self, array: &ConstantArray, _index: usize) -> VortexResult<Scalar> {
77        Ok(array.scalar().clone())
78    }
79}
80
81impl SliceFn<&ConstantArray> for ConstantEncoding {
82    fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
83        Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
84    }
85}
86
87impl FilterKernel for ConstantEncoding {
88    fn filter(&self, array: &ConstantArray, mask: &Mask) -> VortexResult<ArrayRef> {
89        Ok(ConstantArray::new(array.scalar().clone(), mask.true_count()).into_array())
90    }
91}
92
93impl UncompressedSizeFn<&ConstantArray> for ConstantEncoding {
94    fn uncompressed_size(&self, array: &ConstantArray) -> VortexResult<usize> {
95        let scalar = array.scalar();
96
97        let size = match scalar.as_bool_opt() {
98            Some(_) => array.len() / 8,
99            None => array.scalar().nbytes() * array.len(),
100        };
101        Ok(size)
102    }
103}
104
105impl SumFn<&ConstantArray> for ConstantEncoding {
106    fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
107        let sum_dtype = Stat::Sum
108            .dtype(array.dtype())
109            .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
110        let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
111
112        let scalar = array.scalar();
113
114        let scalar_value = match_each_native_ptype!(
115            sum_ptype,
116            unsigned: |$T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() }
117            signed: |$T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() }
118            floating: |$T| { sum_float(scalar.as_primitive(), array.len())?.into() }
119        );
120
121        Ok(Scalar::new(sum_dtype, scalar_value))
122    }
123}
124
125fn sum_integral<T>(
126    primitive_scalar: PrimitiveScalar<'_>,
127    array_len: usize,
128) -> VortexResult<Option<T>>
129where
130    T: FromPrimitiveOrF16 + NativePType + CheckedMul,
131    Scalar: From<Option<T>>,
132{
133    let v = primitive_scalar.as_::<T>()?;
134    let array_len =
135        T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
136    let sum = v.and_then(|v| v.checked_mul(&array_len));
137
138    Ok(sum)
139}
140
141fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
142    let v = primitive_scalar.as_::<f64>()?;
143    let array_len = array_len
144        .to_f64()
145        .ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
146
147    Ok(v.map(|v| v * array_len))
148}
149
#[cfg(test)]
mod test {
    use vortex_dtype::half::f16;
    use vortex_scalar::Scalar;

    use super::ConstantArray;
    use crate::array::Array;
    use crate::compute::conformance::mask::test_mask;

    /// Runs the mask conformance suite over constant arrays of length 5 built from a
    /// null `i32`, a `u16`, an infinite `f32`, and an `f16` scalar.
    #[test]
    fn test_mask_constant() {
        let scalars = [
            Scalar::null_typed::<i32>(),
            Scalar::from(3u16),
            Scalar::from(f32::INFINITY),
            Scalar::from(f16::from_f32(3.0f32)),
        ];

        for scalar in scalars {
            test_mask(&ConstantArray::new(scalar, 5).into_array());
        }
    }
}