// vortex_array/arrays/constant/compute/pipeline.rs

use std::any::Any;
use std::hash::Hash;
use std::rc::Rc;

use vortex_dtype::{DType, NativePType, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult};
use vortex_scalar::Scalar;

use crate::arrays::{ConstantArray, ConstantVTable};
use crate::pipeline::bits::BitView;
use crate::pipeline::operators::{BindContext, Operator};
use crate::pipeline::view::ViewMut;
use crate::pipeline::{Element, Kernel, KernelContext, PipelineVTable, VType};

18impl PipelineVTable<ConstantVTable> for ConstantVTable {
19 fn to_operator(array: &ConstantArray) -> VortexResult<Option<Rc<dyn Operator>>> {
20 Ok(ConstantOperator::maybe_new(array.scalar.clone())
21 .map(|c| Rc::new(c) as Rc<dyn Operator>))
22 }
23}
24
25#[derive(Debug, Hash)]
27pub struct ConstantOperator {
28 pub(crate) scalar: Scalar,
29}
30
31impl ConstantOperator {
32 pub fn maybe_new(scalar: Scalar) -> Option<Self> {
33 if scalar.is_null() || !matches!(scalar.dtype(), DType::Bool(_) | DType::Primitive(..)) {
34 None
35 } else {
36 Some(Self { scalar })
37 }
38 }
39
40 pub fn new(scalar: Scalar) -> Self {
41 Self::maybe_new(scalar).vortex_expect("scalar cannot be null")
42 }
43}
44
45impl Operator for ConstantOperator {
46 fn as_any(&self) -> &dyn Any {
47 self
48 }
49
50 fn vtype(&self) -> VType {
51 match self.scalar.dtype() {
52 DType::Bool(_) => VType::Bool,
53 DType::Primitive(p, _) => VType::Primitive(*p),
54 DType::Binary(_) => VType::Binary,
55 _ => todo!(),
56 }
57 }
58
59 fn children(&self) -> &[Rc<dyn Operator>] {
60 &[]
61 }
62
63 fn with_children(&self, _children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
64 Rc::new(ConstantOperator::new(self.scalar.clone()))
65 }
66
67 fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
68 debug_assert!(matches!(self.vtype(), VType::Bool | VType::Primitive(_)));
69 match self.scalar.dtype() {
70 DType::Bool(_) => Ok(Box::new(BoolConstantKernel {
71 value: self
72 .scalar
73 .as_bool()
74 .value()
75 .vortex_expect("scalar value not bool"),
76 })),
77 DType::Primitive(..) => Ok(match_each_native_ptype!(
78 self.scalar.as_primitive().ptype(),
79 |T| {
80 Box::new(ConstantKernel::<T> {
81 value: self
82 .scalar
83 .as_primitive()
84 .typed_value::<T>()
85 .vortex_expect("scalar value not of type T"),
86 })
87 }
88 )),
89 _ => todo!(
90 "Unsupported scalar type for constant: {:?}",
91 self.scalar.dtype()
92 ),
93 }
94 }
95}
96
97pub struct ConstantKernel<T: NativePType> {
99 value: T,
100}
101
/// Kernel that fills a boolean output with a single value.
#[derive(Debug)]
pub struct BoolConstantKernel {
    /// The value written for every selected element.
    value: bool,
}

107impl<T: Element + NativePType> Kernel for ConstantKernel<T> {
108 fn step(
109 &mut self,
110 _ctx: &KernelContext,
111 selected: BitView,
112 out: &mut ViewMut,
113 ) -> VortexResult<()> {
114 let out_slice = out.as_slice_mut::<T>();
115 for i in 0..selected.true_count() {
116 out_slice[i] = self.value;
117 }
118 Ok(())
119 }
120}
121
122impl Kernel for BoolConstantKernel {
123 fn step(
124 &mut self,
125 _ctx: &KernelContext,
126 selected: BitView,
127 out: &mut ViewMut,
128 ) -> VortexResult<()> {
129 let out_slice = out.as_slice_mut::<bool>();
130 for i in 0..selected.true_count() {
131 out_slice[i] = self.value;
132 }
133 Ok(())
134 }
135}