vortex_array/scalar_fns/binary/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use prost::Message;
5use vortex_compute::arithmetic::Add;
6use vortex_compute::arithmetic::Arithmetic;
7use vortex_compute::arithmetic::CheckedArithmetic;
8use vortex_compute::arithmetic::Div;
9use vortex_compute::arithmetic::Mul;
10use vortex_compute::arithmetic::Sub;
11use vortex_compute::comparison::Compare;
12use vortex_compute::comparison::Equal;
13use vortex_compute::comparison::GreaterThan;
14use vortex_compute::comparison::GreaterThanOrEqual;
15use vortex_compute::comparison::LessThan;
16use vortex_compute::comparison::LessThanOrEqual;
17use vortex_compute::comparison::NotEqual;
18use vortex_compute::logical::KleeneAnd;
19use vortex_compute::logical::KleeneOr;
20use vortex_compute::logical::LogicalOp;
21use vortex_dtype::DType;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_error::vortex_err;
25use vortex_proto::expr as pb;
26use vortex_vector::BoolDatum;
27use vortex_vector::Datum;
28use vortex_vector::PrimitiveDatum;
29
30use crate::expr::ChildName;
31use crate::expr::Operator;
32use crate::expr::functions::ArgName;
33use crate::expr::functions::Arity;
34use crate::expr::functions::ExecutionArgs;
35use crate::expr::functions::FunctionId;
36use crate::expr::functions::NullHandling;
37use crate::expr::functions::VTable;
38
39pub struct BinaryFn;
40impl VTable for BinaryFn {
41    type Options = Operator;
42
43    fn id(&self) -> FunctionId {
44        FunctionId::from("vortex.binary")
45    }
46
47    fn serialize(&self, op: &Operator) -> VortexResult<Option<Vec<u8>>> {
48        Ok(Some(pb::BinaryOpts { op: (*op).into() }.encode_to_vec()))
49    }
50
51    fn deserialize(&self, bytes: &[u8]) -> VortexResult<Operator> {
52        let opts = pb::BinaryOpts::decode(bytes)?;
53        Operator::try_from(opts.op)
54    }
55
56    fn arity(&self, _options: &Operator) -> Arity {
57        Arity::Exact(2)
58    }
59
60    fn null_handling(&self, options: &Operator) -> NullHandling {
61        match options {
62            Operator::And | Operator::Or => NullHandling::AbsorbsNull,
63            _ => NullHandling::Propagate,
64        }
65    }
66
67    fn arg_name(&self, _options: &Operator, arg_idx: usize) -> ArgName {
68        match arg_idx {
69            0 => ChildName::from("lhs"),
70            1 => ChildName::from("rhs"),
71            _ => unreachable!("Binary has only two arguments"),
72        }
73    }
74
75    fn return_dtype(&self, options: &Operator, arg_types: &[DType]) -> VortexResult<DType> {
76        let lhs = &arg_types[0];
77        let rhs = &arg_types[1];
78
79        if options.is_arithmetic() {
80            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
81                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
82            }
83            vortex_bail!(
84                "incompatible types for arithmetic operation: {} {}",
85                lhs,
86                rhs
87            );
88        }
89
90        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
91    }
92
93    fn execute(&self, op: &Operator, args: &ExecutionArgs) -> VortexResult<Datum> {
94        let lhs: Datum = args.input_datums(0).clone();
95        let rhs: Datum = args.input_datums(1).clone();
96
97        match op {
98            Operator::Eq => Ok(Compare::<Equal>::compare(lhs, rhs).into()),
99            Operator::NotEq => Ok(Compare::<NotEqual>::compare(lhs, rhs).into()),
100            Operator::Lt => Ok(Compare::<LessThan>::compare(lhs, rhs).into()),
101            Operator::Lte => Ok(Compare::<LessThanOrEqual>::compare(lhs, rhs).into()),
102            Operator::Gt => Ok(Compare::<GreaterThan>::compare(lhs, rhs).into()),
103            Operator::Gte => Ok(Compare::<GreaterThanOrEqual>::compare(lhs, rhs).into()),
104            Operator::And => Ok(<BoolDatum as LogicalOp<KleeneAnd>>::op(
105                lhs.into_bool(),
106                rhs.into_bool(),
107            )
108            .into()),
109            Operator::Or => {
110                Ok(<BoolDatum as LogicalOp<KleeneOr>>::op(lhs.into_bool(), rhs.into_bool()).into())
111            }
112            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
113                execute_arithmetic_primitive(lhs.into_primitive(), rhs.into_primitive(), *op)
114            }
115        }
116    }
117}
118
119fn execute_arithmetic_primitive(
120    lhs: PrimitiveDatum,
121    rhs: PrimitiveDatum,
122    op: Operator,
123) -> VortexResult<Datum> {
124    // Float arithmetic - no overflow checking needed
125    if lhs.ptype().is_float() && lhs.ptype() == rhs.ptype() {
126        let result: PrimitiveDatum = match op {
127            Operator::Add => Arithmetic::<Add>::eval(lhs, rhs),
128            Operator::Sub => Arithmetic::<Sub>::eval(lhs, rhs),
129            Operator::Mul => Arithmetic::<Mul>::eval(lhs, rhs),
130            Operator::Div => Arithmetic::<Div>::eval(lhs, rhs),
131            _ => unreachable!("Not an arithmetic operator"),
132        };
133        return Ok(result.into());
134    }
135
136    // Integer arithmetic - use checked operations
137    let result: Option<PrimitiveDatum> = match op {
138        Operator::Add => CheckedArithmetic::<Add>::checked_eval(lhs, rhs),
139        Operator::Sub => CheckedArithmetic::<Sub>::checked_eval(lhs, rhs),
140        Operator::Mul => CheckedArithmetic::<Mul>::checked_eval(lhs, rhs),
141        Operator::Div => CheckedArithmetic::<Div>::checked_eval(lhs, rhs),
142        _ => unreachable!("Not an arithmetic operator"),
143    };
144    result
145        .map(|d| d.into())
146        .ok_or_else(|| vortex_err!("Arithmetic overflow/underflow or type mismatch"))
147}
148
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::PTypeDowncast;
    use vortex_error::VortexExpect;
    use vortex_mask::Mask;
    use vortex_vector::Datum;
    use vortex_vector::Scalar;
    use vortex_vector::Vector;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PScalar;
    use vortex_vector::primitive::PVector;
    use vortex_vector::primitive::PrimitiveScalar;
    use vortex_vector::primitive::PrimitiveVector;

    use crate::expr::Operator;
    use crate::expr::functions::ExecutionArgs;
    use crate::expr::functions::VTable;
    use crate::scalar_fns::binary::BinaryFn;

    /// Builds a non-null `i32` scalar datum.
    fn i32_scalar(value: i32) -> Datum {
        Datum::Scalar(Scalar::Primitive(PrimitiveScalar::I32(PScalar::new(Some(value)))))
    }

    #[test]
    fn test_binary() {
        let exec = ExecutionArgs::new(
            100,
            DType::Bool(NonNullable),
            vec![I32.into(), I32.into()],
            vec![i32_scalar(2), i32_scalar(3)],
        );

        let x = BinaryFn
            .execute(&Operator::Gte, &exec)
            .vortex_expect("shouldnt fail");
        assert!(
            !x.into_scalar()
                .vortex_expect("")
                .into_bool()
                .value()
                .vortex_expect("not null")
        );
        let x = BinaryFn
            .execute(&Operator::Lt, &exec)
            .vortex_expect("shouldnt fail");
        assert!(
            x.into_scalar()
                .vortex_expect("")
                .into_bool()
                .value()
                .vortex_expect("not null")
        );
    }

    #[test]
    fn test_add() {
        let exec = ExecutionArgs::new(
            3,
            DType::Primitive(I32, NonNullable),
            vec![I32.into(), I32.into()],
            vec![
                i32_scalar(2),
                Datum::Vector(Vector::Primitive(PrimitiveVector::I32(PVector::new(
                    buffer![1, 2, 3],
                    Mask::AllTrue(3),
                )))),
            ],
        );

        let result = BinaryFn
            .execute(&Operator::Add, &exec)
            .vortex_expect("add should succeed");

        let result_vec = result
            .into_vector()
            .vortex_expect("expected vector result")
            .into_primitive();
        let result_i32: PVector<i32> = result_vec.into_i32();
        assert_eq!(result_i32.elements(), &buffer![3, 4, 5]);
        assert_eq!(result_i32.validity(), &Mask::AllTrue(3));
    }

    /// Integer arithmetic is checked: overflowing `i32::MAX + 1` must surface as an error.
    #[test]
    fn test_add_overflow_is_error() {
        let exec = ExecutionArgs::new(
            1,
            DType::Primitive(I32, NonNullable),
            vec![I32.into(), I32.into()],
            vec![i32_scalar(i32::MAX), i32_scalar(1)],
        );

        assert!(BinaryFn.execute(&Operator::Add, &exec).is_err());
    }

    /// Integer division by zero must surface as an error rather than panicking.
    #[test]
    fn test_div_by_zero_is_error() {
        let exec = ExecutionArgs::new(
            1,
            DType::Primitive(I32, NonNullable),
            vec![I32.into(), I32.into()],
            vec![i32_scalar(1), i32_scalar(0)],
        );

        assert!(BinaryFn.execute(&Operator::Div, &exec).is_err());
    }
}