vortex_array/compute/arrays/
logical.rs1use std::hash::{Hash, Hasher};
5use std::sync::LazyLock;
6
7use enum_map::{Enum, EnumMap, enum_map};
8use vortex_buffer::ByteBuffer;
9use vortex_compute::logical::{
10 LogicalAnd, LogicalAndKleene, LogicalAndNot, LogicalOr, LogicalOrKleene,
11};
12use vortex_dtype::DType;
13use vortex_error::VortexResult;
14use vortex_vector::bool::BoolVector;
15
16use crate::execution::{BatchKernelRef, BindCtx, kernel};
17use crate::serde::ArrayChildren;
18use crate::stats::{ArrayStats, StatsSetRef};
19use crate::vtable::{
20 ArrayVTable, NotSupported, OperatorVTable, SerdeVTable, VTable, VisitorVTable,
21};
22use crate::{
23 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef,
24 DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, Precision, vtable,
25};
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enum)]
29pub enum LogicalOperator {
30 And,
32 AndKleene,
34 Or,
36 OrKleene,
38 AndNot,
40}
41
// Invokes the crate's `vtable!` macro for the `Logical` family.
// NOTE(review): presumably this is what declares `LogicalVTable`, which the rest of this file
// implements traits for — confirm against the macro definition.
vtable!(Logical);
43
44#[derive(Debug, Clone)]
45pub struct LogicalArray {
46 encoding: EncodingRef,
47 lhs: ArrayRef,
48 rhs: ArrayRef,
49 stats: ArrayStats,
50}
51
52impl LogicalArray {
53 pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: LogicalOperator) -> Self {
55 assert_eq!(
56 lhs.len(),
57 rhs.len(),
58 "Logical arrays require lhs and rhs to have the same length"
59 );
60
61 assert!(matches!(lhs.dtype(), DType::Bool(_)));
63 assert_eq!(lhs.dtype(), rhs.dtype());
64
65 Self {
66 encoding: ENCODINGS[operator].clone(),
67 lhs,
68 rhs,
69 stats: ArrayStats::default(),
70 }
71 }
72
73 pub fn operator(&self) -> LogicalOperator {
75 self.encoding.as_::<LogicalVTable>().operator
76 }
77}
78
79#[derive(Debug, Clone)]
80pub struct LogicalEncoding {
81 operator: LogicalOperator,
85}
86
87#[allow(clippy::mem_forget)]
88static ENCODINGS: LazyLock<EnumMap<LogicalOperator, EncodingRef>> = LazyLock::new(|| {
89 enum_map! {
90 operator => LogicalEncoding { operator }.to_encoding(),
91 }
92});
93
94impl VTable for LogicalVTable {
95 type Array = LogicalArray;
96 type Encoding = LogicalEncoding;
97 type ArrayVTable = Self;
98 type CanonicalVTable = NotSupported;
99 type OperationsVTable = NotSupported;
100 type ValidityVTable = NotSupported;
101 type VisitorVTable = Self;
102 type ComputeVTable = NotSupported;
103 type EncodeVTable = NotSupported;
104 type SerdeVTable = Self;
105 type OperatorVTable = Self;
106
107 fn id(encoding: &Self::Encoding) -> EncodingId {
108 match encoding.operator {
109 LogicalOperator::And => EncodingId::from("vortex.and"),
110 LogicalOperator::AndKleene => EncodingId::from("vortex.and_kleene"),
111 LogicalOperator::Or => EncodingId::from("vortex.or"),
112 LogicalOperator::OrKleene => EncodingId::from("vortex.or_kleene"),
113 LogicalOperator::AndNot => EncodingId::from("vortex.and_not"),
114 }
115 }
116
117 fn encoding(array: &Self::Array) -> EncodingRef {
118 array.encoding.clone()
119 }
120}
121
122impl ArrayVTable<LogicalVTable> for LogicalVTable {
123 fn len(array: &LogicalArray) -> usize {
124 array.lhs.len()
125 }
126
127 fn dtype(array: &LogicalArray) -> &DType {
128 array.lhs.dtype()
129 }
130
131 fn stats(array: &LogicalArray) -> StatsSetRef<'_> {
132 array.stats.to_ref(array.as_ref())
133 }
134
135 fn array_hash<H: Hasher>(array: &LogicalArray, state: &mut H, precision: Precision) {
136 array.lhs.array_hash(state, precision);
137 array.rhs.array_hash(state, precision);
138 }
139
140 fn array_eq(array: &LogicalArray, other: &LogicalArray, precision: Precision) -> bool {
141 array.lhs.array_eq(&other.lhs, precision) && array.rhs.array_eq(&other.rhs, precision)
142 }
143}
144
145impl VisitorVTable<LogicalVTable> for LogicalVTable {
146 fn visit_buffers(_array: &LogicalArray, _visitor: &mut dyn ArrayBufferVisitor) {
147 }
149
150 fn visit_children(array: &LogicalArray, visitor: &mut dyn ArrayChildVisitor) {
151 visitor.visit_child("lhs", array.lhs.as_ref());
152 visitor.visit_child("rhs", array.rhs.as_ref());
153 }
154}
155
156impl SerdeVTable<LogicalVTable> for LogicalVTable {
157 type Metadata = EmptyMetadata;
158
159 fn metadata(_array: &LogicalArray) -> VortexResult<Option<Self::Metadata>> {
160 Ok(Some(EmptyMetadata))
161 }
162
163 fn build(
164 encoding: &LogicalEncoding,
165 dtype: &DType,
166 len: usize,
167 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
168 buffers: &[ByteBuffer],
169 children: &dyn ArrayChildren,
170 ) -> VortexResult<LogicalArray> {
171 assert!(buffers.is_empty());
172 Ok(LogicalArray::new(
173 children.get(0, dtype, len)?,
174 children.get(1, dtype, len)?,
175 encoding.operator,
176 ))
177 }
178}
179
180impl OperatorVTable<LogicalVTable> for LogicalVTable {
181 fn bind(
182 array: &LogicalArray,
183 selection: Option<&ArrayRef>,
184 ctx: &mut dyn BindCtx,
185 ) -> VortexResult<BatchKernelRef> {
186 let lhs = ctx.bind(&array.lhs, selection)?;
187 let rhs = ctx.bind(&array.rhs, selection)?;
188
189 Ok(match array.operator() {
190 LogicalOperator::And => logical_kernel(lhs, rhs, |l, r| l.and(&r)),
191 LogicalOperator::AndKleene => logical_kernel(lhs, rhs, |l, r| l.and_kleene(&r)),
192 LogicalOperator::Or => logical_kernel(lhs, rhs, |l, r| l.or(&r)),
193 LogicalOperator::OrKleene => logical_kernel(lhs, rhs, |l, r| l.or_kleene(&r)),
194 LogicalOperator::AndNot => logical_kernel(lhs, rhs, |l, r| l.and_not(&r)),
195 })
196 }
197}
198
199fn logical_kernel<O>(lhs: BatchKernelRef, rhs: BatchKernelRef, op: O) -> BatchKernelRef
201where
202 O: Fn(BoolVector, BoolVector) -> BoolVector + Send + 'static,
203{
204 kernel(move || {
205 let lhs = lhs.execute()?.into_bool();
206 let rhs = rhs.execute()?.into_bool();
207 Ok(op(lhs, rhs).into())
208 })
209}
210
#[cfg(test)]
mod tests {
    use vortex_buffer::bitbuffer;

    use crate::compute::arrays::logical::{LogicalArray, LogicalOperator};
    use crate::{ArrayRef, IntoArray};

    /// Wraps two boolean arrays in a `LogicalArray` using the `And` operator.
    fn and_(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
        let array = LogicalArray::new(lhs, rhs, LogicalOperator::And);
        array.into_array()
    }

    /// AND over full inputs yields the element-wise conjunction.
    #[test]
    fn test_and() {
        let left = bitbuffer![0 1 0].into_array();
        let right = bitbuffer![0 1 1].into_array();

        let out = and_(left, right).execute().unwrap().into_bool();

        assert_eq!(out.bits(), &bitbuffer![0 1 0]);
    }

    /// AND under a selection only produces results for the selected positions.
    #[test]
    fn test_and_selected() {
        let left = bitbuffer![0 1 0].into_array();
        let right = bitbuffer![0 1 1].into_array();

        // Select positions 1 and 2, dropping position 0.
        let selection = bitbuffer![0 1 1].into();

        let out = and_(left, right)
            .execute_with_selection(&selection)
            .unwrap()
            .into_bool();

        assert_eq!(out.bits(), &bitbuffer![1 0]);
    }
}