vortex_array/arrays/constant/compute/pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{ConstantArray, ConstantVTable};
13use crate::pipeline::bits::BitView;
14use crate::pipeline::operators::{BindContext, Operator, OperatorRef};
15use crate::pipeline::view::ViewMut;
16use crate::pipeline::{Element, Kernel, KernelContext, PipelineVTable, VType};
17
18impl PipelineVTable<ConstantVTable> for ConstantVTable {
19    fn to_operator(array: &ConstantArray) -> VortexResult<Option<OperatorRef>> {
20        Ok(ConstantOperator::maybe_new(array.scalar.clone()).map(|c| Arc::new(c) as OperatorRef))
21    }
22}
23
24/// Pipeline operator for constant arrays that produces the same scalar value for all elements.
25#[derive(Debug, Hash)]
26pub struct ConstantOperator {
27    pub(crate) scalar: Scalar,
28}
29
30impl ConstantOperator {
31    pub fn maybe_new(scalar: Scalar) -> Option<Self> {
32        if scalar.is_null() || !matches!(scalar.dtype(), DType::Bool(_) | DType::Primitive(..)) {
33            None
34        } else {
35            Some(Self { scalar })
36        }
37    }
38
39    pub fn new(scalar: Scalar) -> Self {
40        Self::maybe_new(scalar).vortex_expect("scalar cannot be null")
41    }
42}
43
44impl Operator for ConstantOperator {
45    fn as_any(&self) -> &dyn Any {
46        self
47    }
48
49    fn vtype(&self) -> VType {
50        match self.scalar.dtype() {
51            DType::Bool(_) => VType::Bool,
52            DType::Primitive(p, _) => VType::Primitive(*p),
53            DType::Binary(_) => VType::Binary,
54            _ => todo!(),
55        }
56    }
57
58    fn children(&self) -> &[OperatorRef] {
59        &[]
60    }
61
62    fn with_children(&self, _children: Vec<OperatorRef>) -> OperatorRef {
63        Arc::new(ConstantOperator::new(self.scalar.clone()))
64    }
65
66    fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
67        debug_assert!(matches!(self.vtype(), VType::Bool | VType::Primitive(_)));
68        match self.scalar.dtype() {
69            DType::Bool(_) => Ok(Box::new(BoolConstantKernel {
70                value: self
71                    .scalar
72                    .as_bool()
73                    .value()
74                    .vortex_expect("scalar value not bool"),
75            })),
76            DType::Primitive(..) => Ok(match_each_native_ptype!(
77                self.scalar.as_primitive().ptype(),
78                |T| {
79                    Box::new(ConstantKernel::<T> {
80                        value: self
81                            .scalar
82                            .as_primitive()
83                            .typed_value::<T>()
84                            .vortex_expect("scalar value not of type T"),
85                    })
86                }
87            )),
88            _ => todo!(
89                "Unsupported scalar type for constant: {:?}",
90                self.scalar.dtype()
91            ),
92        }
93    }
94}
95
96/// Kernel that produces constant primitive values.
97pub struct ConstantKernel<T: NativePType> {
98    value: T,
99}
100
101/// Kernel that produces constant boolean values.
102pub struct BoolConstantKernel {
103    value: bool,
104}
105
106impl<T: Element + NativePType> Kernel for ConstantKernel<T> {
107    fn step(
108        &mut self,
109        _ctx: &KernelContext,
110        selected: BitView,
111        out: &mut ViewMut,
112    ) -> VortexResult<()> {
113        let out_slice = out.as_slice_mut::<T>();
114        for i in 0..selected.true_count() {
115            out_slice[i] = self.value;
116        }
117        Ok(())
118    }
119}
120
121impl Kernel for BoolConstantKernel {
122    fn step(
123        &mut self,
124        _ctx: &KernelContext,
125        selected: BitView,
126        out: &mut ViewMut,
127    ) -> VortexResult<()> {
128        let out_slice = out.as_slice_mut::<bool>();
129        for i in 0..selected.true_count() {
130            out_slice[i] = self.value;
131        }
132        Ok(())
133    }
134}