vortex_array/arrays/constant/compute/
pipeline.rs1use std::any::Any;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{ConstantArray, ConstantVTable};
13use crate::pipeline::bits::BitView;
14use crate::pipeline::operators::{BindContext, Operator, OperatorRef};
15use crate::pipeline::view::ViewMut;
16use crate::pipeline::{Element, Kernel, KernelContext, PipelineVTable, VType};
17
18impl PipelineVTable<ConstantVTable> for ConstantVTable {
19 fn to_operator(array: &ConstantArray) -> VortexResult<Option<OperatorRef>> {
20 Ok(ConstantOperator::maybe_new(array.scalar.clone()).map(|c| Arc::new(c) as OperatorRef))
21 }
22}
23
24#[derive(Debug, Hash)]
26pub struct ConstantOperator {
27 pub(crate) scalar: Scalar,
28}
29
30impl ConstantOperator {
31 pub fn maybe_new(scalar: Scalar) -> Option<Self> {
32 if scalar.is_null() || !matches!(scalar.dtype(), DType::Bool(_) | DType::Primitive(..)) {
33 None
34 } else {
35 Some(Self { scalar })
36 }
37 }
38
39 pub fn new(scalar: Scalar) -> Self {
40 Self::maybe_new(scalar).vortex_expect("scalar cannot be null")
41 }
42}
43
44impl Operator for ConstantOperator {
45 fn as_any(&self) -> &dyn Any {
46 self
47 }
48
49 fn vtype(&self) -> VType {
50 match self.scalar.dtype() {
51 DType::Bool(_) => VType::Bool,
52 DType::Primitive(p, _) => VType::Primitive(*p),
53 DType::Binary(_) => VType::Binary,
54 _ => todo!(),
55 }
56 }
57
58 fn children(&self) -> &[OperatorRef] {
59 &[]
60 }
61
62 fn with_children(&self, _children: Vec<OperatorRef>) -> OperatorRef {
63 Arc::new(ConstantOperator::new(self.scalar.clone()))
64 }
65
66 fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
67 debug_assert!(matches!(self.vtype(), VType::Bool | VType::Primitive(_)));
68 match self.scalar.dtype() {
69 DType::Bool(_) => Ok(Box::new(BoolConstantKernel {
70 value: self
71 .scalar
72 .as_bool()
73 .value()
74 .vortex_expect("scalar value not bool"),
75 })),
76 DType::Primitive(..) => Ok(match_each_native_ptype!(
77 self.scalar.as_primitive().ptype(),
78 |T| {
79 Box::new(ConstantKernel::<T> {
80 value: self
81 .scalar
82 .as_primitive()
83 .typed_value::<T>()
84 .vortex_expect("scalar value not of type T"),
85 })
86 }
87 )),
88 _ => todo!(
89 "Unsupported scalar type for constant: {:?}",
90 self.scalar.dtype()
91 ),
92 }
93 }
94}
95
/// Kernel that fills its output with a single primitive value.
///
/// Trait bounds live on the `Kernel` impl rather than on the struct definition.
pub struct ConstantKernel<T> {
    /// The value written into each output slot the kernel fills.
    value: T,
}
100
/// Kernel that fills its output with a single boolean value.
pub struct BoolConstantKernel {
    /// The boolean written into each output slot the kernel fills.
    value: bool,
}
105
106impl<T: Element + NativePType> Kernel for ConstantKernel<T> {
107 fn step(
108 &mut self,
109 _ctx: &KernelContext,
110 selected: BitView,
111 out: &mut ViewMut,
112 ) -> VortexResult<()> {
113 let out_slice = out.as_slice_mut::<T>();
114 for i in 0..selected.true_count() {
115 out_slice[i] = self.value;
116 }
117 Ok(())
118 }
119}
120
121impl Kernel for BoolConstantKernel {
122 fn step(
123 &mut self,
124 _ctx: &KernelContext,
125 selected: BitView,
126 out: &mut ViewMut,
127 ) -> VortexResult<()> {
128 let out_slice = out.as_slice_mut::<bool>();
129 for i in 0..selected.true_count() {
130 out_slice[i] = self.value;
131 }
132 Ok(())
133 }
134}