vortex_array/arrays/constant/
operator.rs1use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10
11use crate::arrays::{ConstantArray, ConstantVTable};
12use crate::operator::{Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
13use crate::pipeline::view::ViewMut;
14use crate::pipeline::{BindContext, Element, Kernel, KernelContext, N, PipelinedOperator};
15use crate::vtable::PipelineVTable;
16
17impl PipelineVTable<ConstantVTable> for ConstantVTable {
18 fn to_operator(array: &ConstantArray) -> VortexResult<Option<OperatorRef>> {
19 Ok(Some(Arc::new(array.clone())))
20 }
21}
22
23impl OperatorHash for ConstantArray {
24 fn operator_hash<H: Hasher>(&self, state: &mut H) {
25 self.scalar.hash(state);
26 self.len.hash(state);
27 }
28}
29
30impl OperatorEq for ConstantArray {
31 fn operator_eq(&self, other: &Self) -> bool {
32 self.scalar == other.scalar && self.len == other.len
33 }
34}
35
36impl Operator for ConstantArray {
37 fn id(&self) -> OperatorId {
38 self.encoding_id()
39 }
40
41 fn as_any(&self) -> &dyn Any {
42 self
43 }
44
45 fn dtype(&self) -> &DType {
46 self.scalar.dtype()
47 }
48
49 fn len(&self) -> usize {
50 self.len
51 }
52
53 fn children(&self) -> &[OperatorRef] {
54 &[]
55 }
56
57 fn with_children(self: Arc<Self>, _children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
58 Ok(self)
59 }
60}
61
62impl PipelinedOperator for ConstantArray {
63 fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
64 debug_assert!(matches!(
65 self.dtype(),
66 DType::Bool(_) | DType::Primitive(..)
67 ));
68 match self.scalar.dtype() {
69 DType::Bool(_) => Ok(Box::new(BoolConstantKernel {
70 remaining: self.len,
71 value: self
72 .scalar
73 .as_bool()
74 .value()
75 .vortex_expect("scalar value not bool"),
76 })),
77 DType::Primitive(..) => Ok(match_each_native_ptype!(
78 self.scalar.as_primitive().ptype(),
79 |T| {
80 Box::new(ConstantKernel::<T> {
81 remaining: self.len,
82 value: self
83 .scalar
84 .as_primitive()
85 .typed_value::<T>()
86 .vortex_expect("scalar value not of type T"),
87 })
88 }
89 )),
90 _ => todo!(
91 "Unsupported scalar type for constant: {:?}",
92 self.scalar.dtype()
93 ),
94 }
95 }
96
97 fn vector_children(&self) -> Vec<usize> {
98 vec![]
99 }
100
101 fn batch_children(&self) -> Vec<usize> {
102 vec![]
103 }
104}
105
106pub struct ConstantKernel<T: NativePType> {
108 remaining: usize,
109 value: T,
110}
111
112pub struct BoolConstantKernel {
114 remaining: usize,
115 value: bool,
116}
117
118impl<T: Element + NativePType> Kernel for ConstantKernel<T> {
119 fn step(&mut self, _ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
120 out.as_slice_mut::<T>()[..N].fill(self.value);
121 let len = self.remaining.min(N);
122 out.set_len(len);
123 Ok(())
124 }
125}
126
127impl Kernel for BoolConstantKernel {
128 fn step(&mut self, _ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
129 out.as_slice_mut::<bool>()[..N].fill(self.value);
130 let len = self.remaining.min(N);
131 out.set_len(len);
132 Ok(())
133 }
134}