vortex_array/scalar_fns/binary/
mod.rs1use prost::Message;
5use vortex_compute::arithmetic::Add;
6use vortex_compute::arithmetic::Arithmetic;
7use vortex_compute::arithmetic::CheckedArithmetic;
8use vortex_compute::arithmetic::Div;
9use vortex_compute::arithmetic::Mul;
10use vortex_compute::arithmetic::Sub;
11use vortex_compute::comparison::Compare;
12use vortex_compute::comparison::Equal;
13use vortex_compute::comparison::GreaterThan;
14use vortex_compute::comparison::GreaterThanOrEqual;
15use vortex_compute::comparison::LessThan;
16use vortex_compute::comparison::LessThanOrEqual;
17use vortex_compute::comparison::NotEqual;
18use vortex_compute::logical::KleeneAnd;
19use vortex_compute::logical::KleeneOr;
20use vortex_compute::logical::LogicalOp;
21use vortex_dtype::DType;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_error::vortex_err;
25use vortex_proto::expr as pb;
26use vortex_vector::BoolDatum;
27use vortex_vector::Datum;
28use vortex_vector::PrimitiveDatum;
29
30use crate::expr::ChildName;
31use crate::expr::Operator;
32use crate::expr::functions::ArgName;
33use crate::expr::functions::Arity;
34use crate::expr::functions::ExecutionArgs;
35use crate::expr::functions::FunctionId;
36use crate::expr::functions::NullHandling;
37use crate::expr::functions::VTable;
38
39pub struct BinaryFn;
40impl VTable for BinaryFn {
41 type Options = Operator;
42
43 fn id(&self) -> FunctionId {
44 FunctionId::from("vortex.binary")
45 }
46
47 fn serialize(&self, op: &Operator) -> VortexResult<Option<Vec<u8>>> {
48 Ok(Some(pb::BinaryOpts { op: (*op).into() }.encode_to_vec()))
49 }
50
51 fn deserialize(&self, bytes: &[u8]) -> VortexResult<Operator> {
52 let opts = pb::BinaryOpts::decode(bytes)?;
53 Operator::try_from(opts.op)
54 }
55
56 fn arity(&self, _options: &Operator) -> Arity {
57 Arity::Exact(2)
58 }
59
60 fn null_handling(&self, options: &Operator) -> NullHandling {
61 match options {
62 Operator::And | Operator::Or => NullHandling::AbsorbsNull,
63 _ => NullHandling::Propagate,
64 }
65 }
66
67 fn arg_name(&self, _options: &Operator, arg_idx: usize) -> ArgName {
68 match arg_idx {
69 0 => ChildName::from("lhs"),
70 1 => ChildName::from("rhs"),
71 _ => unreachable!("Binary has only two arguments"),
72 }
73 }
74
75 fn return_dtype(&self, options: &Operator, arg_types: &[DType]) -> VortexResult<DType> {
76 let lhs = &arg_types[0];
77 let rhs = &arg_types[1];
78
79 if options.is_arithmetic() {
80 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
81 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
82 }
83 vortex_bail!(
84 "incompatible types for arithmetic operation: {} {}",
85 lhs,
86 rhs
87 );
88 }
89
90 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
91 }
92
93 fn execute(&self, op: &Operator, args: &ExecutionArgs) -> VortexResult<Datum> {
94 let lhs: Datum = args.input_datums(0).clone();
95 let rhs: Datum = args.input_datums(1).clone();
96
97 match op {
98 Operator::Eq => Ok(Compare::<Equal>::compare(lhs, rhs).into()),
99 Operator::NotEq => Ok(Compare::<NotEqual>::compare(lhs, rhs).into()),
100 Operator::Lt => Ok(Compare::<LessThan>::compare(lhs, rhs).into()),
101 Operator::Lte => Ok(Compare::<LessThanOrEqual>::compare(lhs, rhs).into()),
102 Operator::Gt => Ok(Compare::<GreaterThan>::compare(lhs, rhs).into()),
103 Operator::Gte => Ok(Compare::<GreaterThanOrEqual>::compare(lhs, rhs).into()),
104 Operator::And => Ok(<BoolDatum as LogicalOp<KleeneAnd>>::op(
105 lhs.into_bool(),
106 rhs.into_bool(),
107 )
108 .into()),
109 Operator::Or => {
110 Ok(<BoolDatum as LogicalOp<KleeneOr>>::op(lhs.into_bool(), rhs.into_bool()).into())
111 }
112 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
113 execute_arithmetic_primitive(lhs.into_primitive(), rhs.into_primitive(), *op)
114 }
115 }
116 }
117}
118
119fn execute_arithmetic_primitive(
120 lhs: PrimitiveDatum,
121 rhs: PrimitiveDatum,
122 op: Operator,
123) -> VortexResult<Datum> {
124 if lhs.ptype().is_float() && lhs.ptype() == rhs.ptype() {
126 let result: PrimitiveDatum = match op {
127 Operator::Add => Arithmetic::<Add>::eval(lhs, rhs),
128 Operator::Sub => Arithmetic::<Sub>::eval(lhs, rhs),
129 Operator::Mul => Arithmetic::<Mul>::eval(lhs, rhs),
130 Operator::Div => Arithmetic::<Div>::eval(lhs, rhs),
131 _ => unreachable!("Not an arithmetic operator"),
132 };
133 return Ok(result.into());
134 }
135
136 let result: Option<PrimitiveDatum> = match op {
138 Operator::Add => CheckedArithmetic::<Add>::checked_eval(lhs, rhs),
139 Operator::Sub => CheckedArithmetic::<Sub>::checked_eval(lhs, rhs),
140 Operator::Mul => CheckedArithmetic::<Mul>::checked_eval(lhs, rhs),
141 Operator::Div => CheckedArithmetic::<Div>::checked_eval(lhs, rhs),
142 _ => unreachable!("Not an arithmetic operator"),
143 };
144 result
145 .map(|d| d.into())
146 .ok_or_else(|| vortex_err!("Arithmetic overflow/underflow or type mismatch"))
147}
148
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::PTypeDowncast;
    use vortex_error::VortexExpect;
    use vortex_mask::Mask;
    use vortex_vector::Datum;
    use vortex_vector::Scalar;
    use vortex_vector::Vector;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PScalar;
    use vortex_vector::primitive::PVector;
    use vortex_vector::primitive::PrimitiveScalar;
    use vortex_vector::primitive::PrimitiveVector;

    use crate::expr::Operator;
    use crate::expr::functions::ExecutionArgs;
    use crate::expr::functions::VTable;
    use crate::scalar_fns::binary::BinaryFn;

    /// Wraps a non-null `i32` in a scalar [`Datum`].
    fn i32_scalar(value: i32) -> Datum {
        let scalar = PrimitiveScalar::I32(PScalar::new(Some(value)));
        Datum::Scalar(Scalar::Primitive(scalar))
    }

    /// Extracts the non-null boolean held by a scalar [`Datum`].
    fn scalar_bool(datum: Datum) -> bool {
        datum
            .into_scalar()
            .vortex_expect("")
            .into_bool()
            .value()
            .vortex_expect("not null")
    }

    /// Comparing the scalars 2 and 3: `>=` must be false and `<` must be true.
    #[test]
    fn test_binary() {
        let inputs = vec![i32_scalar(2i32), i32_scalar(3i32)];
        let exec = ExecutionArgs::new(
            100,
            DType::Bool(NonNullable),
            vec![I32.into(), I32.into()],
            inputs,
        );

        let gte = BinaryFn
            .execute(&Operator::Gte, &exec)
            .vortex_expect("shouldnt fail");
        assert!(!scalar_bool(gte));

        let lt = BinaryFn
            .execute(&Operator::Lt, &exec)
            .vortex_expect("shouldnt fail");
        assert!(scalar_bool(lt));
    }

    /// Adding the scalar 2 to the vector `[1, 2, 3]` yields `[3, 4, 5]`, all valid.
    #[test]
    fn test_add() {
        let vector_operand = Datum::Vector(Vector::Primitive(PrimitiveVector::I32(
            PVector::new(buffer![1, 2, 3], Mask::AllTrue(3)),
        )));
        let exec = ExecutionArgs::new(
            3,
            DType::Primitive(I32, NonNullable),
            vec![I32.into(), I32.into()],
            vec![i32_scalar(2i32), vector_operand],
        );

        let sums: PVector<i32> = BinaryFn
            .execute(&Operator::Add, &exec)
            .vortex_expect("add should succeed")
            .into_vector()
            .vortex_expect("expected vector result")
            .into_primitive()
            .into_i32();

        assert_eq!(sums.elements(), &buffer![3, 4, 5]);
        assert_eq!(sums.validity(), &Mask::AllTrue(3));
    }
}