vortex_array/compute/arrays/
logical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::{Hash, Hasher};
5use std::sync::LazyLock;
6
7use enum_map::{Enum, EnumMap, enum_map};
8use vortex_buffer::ByteBuffer;
9use vortex_compute::logical::{
10    LogicalAnd, LogicalAndKleene, LogicalAndNot, LogicalOr, LogicalOrKleene,
11};
12use vortex_dtype::DType;
13use vortex_error::VortexResult;
14use vortex_vector::bool::BoolVector;
15
16use crate::execution::{BatchKernelRef, BindCtx, kernel};
17use crate::serde::ArrayChildren;
18use crate::stats::{ArrayStats, StatsSetRef};
19use crate::vtable::{
20    ArrayVTable, NotSupported, OperatorVTable, SerdeVTable, VTable, VisitorVTable,
21};
22use crate::{
23    Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef,
24    DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, Precision, vtable,
25};
26
/// The set of operators supported by a logical array.
///
/// Each variant has its own encoding (and encoding ID), so the operator of a
/// [`LogicalArray`] is recovered from its encoding rather than stored separately.
/// The `Enum` derive is what allows this type to key the per-operator `EnumMap` of encodings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enum)]
pub enum LogicalOperator {
    /// Logical AND
    And,
    /// Logical AND with Kleene logic
    /// (presumably null-aware three-valued logic — confirm against `vortex_compute::logical`).
    AndKleene,
    /// Logical OR
    Or,
    /// Logical OR with Kleene logic
    /// (presumably null-aware three-valued logic — confirm against `vortex_compute::logical`).
    OrKleene,
    /// Logical AND NOT, evaluated as `lhs.and_not(&rhs)`
    /// (presumably `lhs AND (NOT rhs)` — confirm against `vortex_compute::logical`).
    AndNot,
}
41
// Generates the vtable plumbing for this encoding. Presumably this is where the
// `LogicalVTable` type used throughout this file is declared, since it has no explicit
// definition in the visible source — confirm against the `vtable!` macro.
vtable!(Logical);
43
/// An array that combines two boolean arrays with a [`LogicalOperator`].
///
/// Construction does no computation; the result is produced by the batch kernel
/// created when the array is bound for execution.
#[derive(Debug, Clone)]
pub struct LogicalArray {
    /// Encoding for this array's operator; `operator()` reads the operator back from it.
    encoding: EncodingRef,
    /// Left operand. `new` requires it to be boolean and match `rhs` in length and dtype.
    lhs: ArrayRef,
    /// Right operand.
    rhs: ArrayRef,
    /// Statistics for this array, exposed through `ArrayVTable::stats`.
    stats: ArrayStats,
}
51
52impl LogicalArray {
53    /// Create a new logical array.
54    pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: LogicalOperator) -> Self {
55        assert_eq!(
56            lhs.len(),
57            rhs.len(),
58            "Logical arrays require lhs and rhs to have the same length"
59        );
60
61        // TODO(ngates): should we automatically cast non-null to nullable if required?
62        assert!(matches!(lhs.dtype(), DType::Bool(_)));
63        assert_eq!(lhs.dtype(), rhs.dtype());
64
65        Self {
66            encoding: ENCODINGS[operator].clone(),
67            lhs,
68            rhs,
69            stats: ArrayStats::default(),
70        }
71    }
72
73    /// Returns the operator of this logical array.
74    pub fn operator(&self) -> LogicalOperator {
75        self.encoding.as_::<LogicalVTable>().operator
76    }
77}
78
/// Encoding for [`LogicalArray`], parameterized by the operator it applies.
#[derive(Debug, Clone)]
pub struct LogicalEncoding {
    // We include the operator in the encoding so each operator is a different encoding ID.
    // This makes it easier for plugins to construct expressions and perform pushdown
    // optimizations.
    operator: LogicalOperator,
}
86
/// One encoding per [`LogicalOperator`], built lazily on first access.
///
/// `LogicalArray::new` stores a clone of the entry for its operator.
// NOTE(review): the reason for allowing `clippy::mem_forget` is not visible here; presumably
// the lint fires inside the `enum_map!` expansion or `to_encoding()`. Confirm and document why.
#[allow(clippy::mem_forget)]
static ENCODINGS: LazyLock<EnumMap<LogicalOperator, EncodingRef>> = LazyLock::new(|| {
    enum_map! {
        operator => LogicalEncoding { operator }.to_encoding(),
    }
});
93
94impl VTable for LogicalVTable {
95    type Array = LogicalArray;
96    type Encoding = LogicalEncoding;
97    type ArrayVTable = Self;
98    type CanonicalVTable = NotSupported;
99    type OperationsVTable = NotSupported;
100    type ValidityVTable = NotSupported;
101    type VisitorVTable = Self;
102    type ComputeVTable = NotSupported;
103    type EncodeVTable = NotSupported;
104    type SerdeVTable = Self;
105    type OperatorVTable = Self;
106
107    fn id(encoding: &Self::Encoding) -> EncodingId {
108        match encoding.operator {
109            LogicalOperator::And => EncodingId::from("vortex.and"),
110            LogicalOperator::AndKleene => EncodingId::from("vortex.and_kleene"),
111            LogicalOperator::Or => EncodingId::from("vortex.or"),
112            LogicalOperator::OrKleene => EncodingId::from("vortex.or_kleene"),
113            LogicalOperator::AndNot => EncodingId::from("vortex.and_not"),
114        }
115    }
116
117    fn encoding(array: &Self::Array) -> EncodingRef {
118        array.encoding.clone()
119    }
120}
121
122impl ArrayVTable<LogicalVTable> for LogicalVTable {
123    fn len(array: &LogicalArray) -> usize {
124        array.lhs.len()
125    }
126
127    fn dtype(array: &LogicalArray) -> &DType {
128        array.lhs.dtype()
129    }
130
131    fn stats(array: &LogicalArray) -> StatsSetRef<'_> {
132        array.stats.to_ref(array.as_ref())
133    }
134
135    fn array_hash<H: Hasher>(array: &LogicalArray, state: &mut H, precision: Precision) {
136        array.lhs.array_hash(state, precision);
137        array.rhs.array_hash(state, precision);
138    }
139
140    fn array_eq(array: &LogicalArray, other: &LogicalArray, precision: Precision) -> bool {
141        array.lhs.array_eq(&other.lhs, precision) && array.rhs.array_eq(&other.rhs, precision)
142    }
143}
144
145impl VisitorVTable<LogicalVTable> for LogicalVTable {
146    fn visit_buffers(_array: &LogicalArray, _visitor: &mut dyn ArrayBufferVisitor) {
147        // No buffers
148    }
149
150    fn visit_children(array: &LogicalArray, visitor: &mut dyn ArrayChildVisitor) {
151        visitor.visit_child("lhs", array.lhs.as_ref());
152        visitor.visit_child("rhs", array.rhs.as_ref());
153    }
154}
155
156impl SerdeVTable<LogicalVTable> for LogicalVTable {
157    type Metadata = EmptyMetadata;
158
159    fn metadata(_array: &LogicalArray) -> VortexResult<Option<Self::Metadata>> {
160        Ok(Some(EmptyMetadata))
161    }
162
163    fn build(
164        encoding: &LogicalEncoding,
165        dtype: &DType,
166        len: usize,
167        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
168        buffers: &[ByteBuffer],
169        children: &dyn ArrayChildren,
170    ) -> VortexResult<LogicalArray> {
171        assert!(buffers.is_empty());
172        Ok(LogicalArray::new(
173            children.get(0, dtype, len)?,
174            children.get(1, dtype, len)?,
175            encoding.operator,
176        ))
177    }
178}
179
180impl OperatorVTable<LogicalVTable> for LogicalVTable {
181    fn bind(
182        array: &LogicalArray,
183        selection: Option<&ArrayRef>,
184        ctx: &mut dyn BindCtx,
185    ) -> VortexResult<BatchKernelRef> {
186        let lhs = ctx.bind(&array.lhs, selection)?;
187        let rhs = ctx.bind(&array.rhs, selection)?;
188
189        Ok(match array.operator() {
190            LogicalOperator::And => logical_kernel(lhs, rhs, |l, r| l.and(&r)),
191            LogicalOperator::AndKleene => logical_kernel(lhs, rhs, |l, r| l.and_kleene(&r)),
192            LogicalOperator::Or => logical_kernel(lhs, rhs, |l, r| l.or(&r)),
193            LogicalOperator::OrKleene => logical_kernel(lhs, rhs, |l, r| l.or_kleene(&r)),
194            LogicalOperator::AndNot => logical_kernel(lhs, rhs, |l, r| l.and_not(&r)),
195        })
196    }
197}
198
199/// Batch execution kernel for logical operations.
200fn logical_kernel<O>(lhs: BatchKernelRef, rhs: BatchKernelRef, op: O) -> BatchKernelRef
201where
202    O: Fn(BoolVector, BoolVector) -> BoolVector + Send + 'static,
203{
204    kernel(move || {
205        let lhs = lhs.execute()?.into_bool();
206        let rhs = rhs.execute()?.into_bool();
207        Ok(op(lhs, rhs).into())
208    })
209}
210
#[cfg(test)]
mod tests {
    use vortex_buffer::bitbuffer;

    use crate::compute::arrays::logical::{LogicalArray, LogicalOperator};
    use crate::{ArrayRef, IntoArray};

    /// Wraps `lhs` and `rhs` in a logical array applying `operator`.
    fn logical(lhs: ArrayRef, rhs: ArrayRef, operator: LogicalOperator) -> ArrayRef {
        LogicalArray::new(lhs, rhs, operator).into_array()
    }

    fn and_(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
        logical(lhs, rhs, LogicalOperator::And)
    }

    #[test]
    fn test_and() {
        let lhs = bitbuffer![0 1 0].into_array();
        let rhs = bitbuffer![0 1 1].into_array();
        let result = and_(lhs, rhs).execute().unwrap().into_bool();
        assert_eq!(result.bits(), &bitbuffer![0 1 0]);
    }

    #[test]
    fn test_and_selected() {
        let lhs = bitbuffer![0 1 0].into_array();
        let rhs = bitbuffer![0 1 1].into_array();

        let selection = bitbuffer![0 1 1].into();

        let result = and_(lhs, rhs)
            .execute_with_selection(&selection)
            .unwrap()
            .into_bool();
        assert_eq!(result.bits(), &bitbuffer![1 0]);
    }

    // The tests below cover every (lhs, rhs) combination: (0,0), (0,1), (1,0), (1,1).

    #[test]
    fn test_or() {
        let lhs = bitbuffer![0 0 1 1].into_array();
        let rhs = bitbuffer![0 1 0 1].into_array();
        let result = logical(lhs, rhs, LogicalOperator::Or)
            .execute()
            .unwrap()
            .into_bool();
        assert_eq!(result.bits(), &bitbuffer![0 1 1 1]);
    }

    #[test]
    fn test_and_kleene_without_nulls() {
        // With no nulls present, Kleene AND must agree with plain AND.
        let lhs = bitbuffer![0 0 1 1].into_array();
        let rhs = bitbuffer![0 1 0 1].into_array();
        let result = logical(lhs, rhs, LogicalOperator::AndKleene)
            .execute()
            .unwrap()
            .into_bool();
        assert_eq!(result.bits(), &bitbuffer![0 0 0 1]);
    }

    #[test]
    fn test_or_kleene_without_nulls() {
        // With no nulls present, Kleene OR must agree with plain OR.
        let lhs = bitbuffer![0 0 1 1].into_array();
        let rhs = bitbuffer![0 1 0 1].into_array();
        let result = logical(lhs, rhs, LogicalOperator::OrKleene)
            .execute()
            .unwrap()
            .into_bool();
        assert_eq!(result.bits(), &bitbuffer![0 1 1 1]);
    }

    #[test]
    fn test_and_not() {
        // lhs AND (NOT rhs)
        let lhs = bitbuffer![0 0 1 1].into_array();
        let rhs = bitbuffer![0 1 0 1].into_array();
        let result = logical(lhs, rhs, LogicalOperator::AndNot)
            .execute()
            .unwrap()
            .into_bool();
        assert_eq!(result.bits(), &bitbuffer![0 0 1 0]);
    }
}