vortex_array/arrays/constant/operator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10
11use crate::arrays::{ConstantArray, ConstantVTable};
12use crate::operator::{Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
13use crate::pipeline::view::ViewMut;
14use crate::pipeline::{BindContext, Element, Kernel, KernelContext, N, PipelinedOperator};
15use crate::vtable::PipelineVTable;
16
17impl PipelineVTable<ConstantVTable> for ConstantVTable {
18    fn to_operator(array: &ConstantArray) -> VortexResult<Option<OperatorRef>> {
19        Ok(Some(Arc::new(array.clone())))
20    }
21}
22
23impl OperatorHash for ConstantArray {
24    fn operator_hash<H: Hasher>(&self, state: &mut H) {
25        self.scalar.hash(state);
26        self.len.hash(state);
27    }
28}
29
30impl OperatorEq for ConstantArray {
31    fn operator_eq(&self, other: &Self) -> bool {
32        self.scalar == other.scalar && self.len == other.len
33    }
34}
35
36impl Operator for ConstantArray {
37    fn id(&self) -> OperatorId {
38        self.encoding_id()
39    }
40
41    fn as_any(&self) -> &dyn Any {
42        self
43    }
44
45    fn dtype(&self) -> &DType {
46        self.scalar.dtype()
47    }
48
49    fn len(&self) -> usize {
50        self.len
51    }
52
53    fn children(&self) -> &[OperatorRef] {
54        &[]
55    }
56
57    fn with_children(self: Arc<Self>, _children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
58        Ok(self)
59    }
60}
61
62impl PipelinedOperator for ConstantArray {
63    fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
64        debug_assert!(matches!(
65            self.dtype(),
66            DType::Bool(_) | DType::Primitive(..)
67        ));
68        match self.scalar.dtype() {
69            DType::Bool(_) => Ok(Box::new(BoolConstantKernel {
70                remaining: self.len,
71                value: self
72                    .scalar
73                    .as_bool()
74                    .value()
75                    .vortex_expect("scalar value not bool"),
76            })),
77            DType::Primitive(..) => Ok(match_each_native_ptype!(
78                self.scalar.as_primitive().ptype(),
79                |T| {
80                    Box::new(ConstantKernel::<T> {
81                        remaining: self.len,
82                        value: self
83                            .scalar
84                            .as_primitive()
85                            .typed_value::<T>()
86                            .vortex_expect("scalar value not of type T"),
87                    })
88                }
89            )),
90            _ => todo!(
91                "Unsupported scalar type for constant: {:?}",
92                self.scalar.dtype()
93            ),
94        }
95    }
96
97    fn vector_children(&self) -> Vec<usize> {
98        vec![]
99    }
100
101    fn batch_children(&self) -> Vec<usize> {
102        vec![]
103    }
104}
105
106/// Kernel that produces constant primitive values.
107pub struct ConstantKernel<T: NativePType> {
108    remaining: usize,
109    value: T,
110}
111
/// Kernel that produces constant boolean values.
///
/// Each step writes `value` into the output view; `remaining` caps how many of those
/// elements are reported as valid.
pub struct BoolConstantKernel {
    /// The boolean written into every output slot.
    value: bool,
    /// Elements of the constant array left to produce; a step emits at most this many.
    remaining: usize,
}
117
118impl<T: Element + NativePType> Kernel for ConstantKernel<T> {
119    fn step(&mut self, _ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
120        out.as_slice_mut::<T>()[..N].fill(self.value);
121        let len = self.remaining.min(N);
122        out.set_len(len);
123        Ok(())
124    }
125}
126
127impl Kernel for BoolConstantKernel {
128    fn step(&mut self, _ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
129        out.as_slice_mut::<bool>()[..N].fill(self.value);
130        let len = self.remaining.min(N);
131        out.set_len(len);
132        Ok(())
133    }
134}