vortex_array/arrays/constant/compute/
pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::Hash;
6use std::rc::Rc;
7
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{ConstantArray, ConstantVTable};
13use crate::pipeline::bits::BitView;
14use crate::pipeline::operators::{BindContext, Operator};
15use crate::pipeline::view::ViewMut;
16use crate::pipeline::{Element, Kernel, KernelContext, PipelineVTable, VType};
17
18impl PipelineVTable<ConstantVTable> for ConstantVTable {
19    fn to_operator(array: &ConstantArray) -> VortexResult<Option<Rc<dyn Operator>>> {
20        Ok(ConstantOperator::maybe_new(array.scalar.clone())
21            .map(|c| Rc::new(c) as Rc<dyn Operator>))
22    }
23}
24
25/// Pipeline operator for constant arrays that produces the same scalar value for all elements.
26#[derive(Debug, Hash)]
27pub struct ConstantOperator {
28    pub(crate) scalar: Scalar,
29}
30
31impl ConstantOperator {
32    pub fn maybe_new(scalar: Scalar) -> Option<Self> {
33        if scalar.is_null() || !matches!(scalar.dtype(), DType::Bool(_) | DType::Primitive(..)) {
34            None
35        } else {
36            Some(Self { scalar })
37        }
38    }
39
40    pub fn new(scalar: Scalar) -> Self {
41        Self::maybe_new(scalar).vortex_expect("scalar cannot be null")
42    }
43}
44
45impl Operator for ConstantOperator {
46    fn as_any(&self) -> &dyn Any {
47        self
48    }
49
50    fn vtype(&self) -> VType {
51        match self.scalar.dtype() {
52            DType::Bool(_) => VType::Bool,
53            DType::Primitive(p, _) => VType::Primitive(*p),
54            DType::Binary(_) => VType::Binary,
55            _ => todo!(),
56        }
57    }
58
59    fn children(&self) -> &[Rc<dyn Operator>] {
60        &[]
61    }
62
63    fn with_children(&self, _children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
64        Rc::new(ConstantOperator::new(self.scalar.clone()))
65    }
66
67    fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
68        debug_assert!(matches!(self.vtype(), VType::Bool | VType::Primitive(_)));
69        match self.scalar.dtype() {
70            DType::Bool(_) => Ok(Box::new(BoolConstantKernel {
71                value: self
72                    .scalar
73                    .as_bool()
74                    .value()
75                    .vortex_expect("scalar value not bool"),
76            })),
77            DType::Primitive(..) => Ok(match_each_native_ptype!(
78                self.scalar.as_primitive().ptype(),
79                |T| {
80                    Box::new(ConstantKernel::<T> {
81                        value: self
82                            .scalar
83                            .as_primitive()
84                            .typed_value::<T>()
85                            .vortex_expect("scalar value not of type T"),
86                    })
87                }
88            )),
89            _ => todo!(
90                "Unsupported scalar type for constant: {:?}",
91                self.scalar.dtype()
92            ),
93        }
94    }
95}
96
97/// Kernel that produces constant primitive values.
98pub struct ConstantKernel<T: NativePType> {
99    value: T,
100}
101
/// Pipeline kernel that emits one fixed boolean value.
pub struct BoolConstantKernel {
    /// The boolean written to the output on every step.
    value: bool,
}
106
107impl<T: Element + NativePType> Kernel for ConstantKernel<T> {
108    fn step(
109        &mut self,
110        _ctx: &KernelContext,
111        selected: BitView,
112        out: &mut ViewMut,
113    ) -> VortexResult<()> {
114        let out_slice = out.as_slice_mut::<T>();
115        for i in 0..selected.true_count() {
116            out_slice[i] = self.value;
117        }
118        Ok(())
119    }
120}
121
122impl Kernel for BoolConstantKernel {
123    fn step(
124        &mut self,
125        _ctx: &KernelContext,
126        selected: BitView,
127        out: &mut ViewMut,
128    ) -> VortexResult<()> {
129        let out_slice = out.as_slice_mut::<bool>();
130        for i in 0..selected.true_count() {
131            out_slice[i] = self.value;
132        }
133        Ok(())
134    }
135}