// vegafusion_runtime/expression/compiler/binary.rs

1use crate::expression::compiler::{compile, config::CompilationConfig};
2use datafusion_expr::expr::BinaryExpr;
3use datafusion_expr::{lit, when, Expr, Operator};
4use datafusion_functions::string::expr_fn::concat;
5use vegafusion_common::datafusion_common::DFSchema;
6use vegafusion_common::datatypes::{
7    cast_to, data_type, is_null_literal, is_numeric_datatype, is_string_datatype, to_numeric,
8    to_string,
9};
10
11use vegafusion_core::arrow::datatypes::DataType;
12use vegafusion_core::error::Result;
13use vegafusion_core::proto::gen::expression::{BinaryExpression, BinaryOperator};
14
15pub fn compile_binary(
16    node: &BinaryExpression,
17    config: &CompilationConfig,
18    schema: &DFSchema,
19) -> Result<Expr> {
20    // First, compile argument
21    let lhs = compile(node.left(), config, Some(schema))?;
22    let rhs = compile(node.right(), config, Some(schema))?;
23
24    let lhs_dtype = data_type(&lhs, schema)?;
25    let rhs_dtype = data_type(&rhs, schema)?;
26    let lhs_numeric = to_numeric(lhs.clone(), schema)?;
27    let rhs_numeric = to_numeric(rhs.clone(), schema)?;
28
29    let new_expr: Expr = match node.to_operator() {
30        BinaryOperator::Minus => Expr::BinaryExpr(BinaryExpr {
31            left: Box::new(lhs_numeric),
32            op: Operator::Minus,
33            right: Box::new(rhs_numeric),
34        }),
35        BinaryOperator::Mult => Expr::BinaryExpr(BinaryExpr {
36            left: Box::new(lhs_numeric),
37            op: Operator::Multiply,
38            right: Box::new(rhs_numeric),
39        }),
40        BinaryOperator::Div => Expr::BinaryExpr(BinaryExpr {
41            left: Box::new(cast_to(lhs_numeric, &DataType::Float64, schema)?),
42            op: Operator::Divide,
43            right: Box::new(rhs_numeric),
44        }),
45        BinaryOperator::Mod => Expr::BinaryExpr(BinaryExpr {
46            left: Box::new(lhs_numeric),
47            op: Operator::Modulo,
48            right: Box::new(rhs_numeric),
49        }),
50        BinaryOperator::LessThan => Expr::BinaryExpr(BinaryExpr {
51            left: Box::new(lhs_numeric),
52            op: Operator::Lt,
53            right: Box::new(rhs_numeric),
54        }),
55        BinaryOperator::LessThanEqual => Expr::BinaryExpr(BinaryExpr {
56            left: Box::new(lhs_numeric),
57            op: Operator::LtEq,
58            right: Box::new(rhs_numeric),
59        }),
60        BinaryOperator::GreaterThan => Expr::BinaryExpr(BinaryExpr {
61            left: Box::new(lhs_numeric),
62            op: Operator::Gt,
63            right: Box::new(rhs_numeric),
64        }),
65        BinaryOperator::GreaterThanEqual => Expr::BinaryExpr(BinaryExpr {
66            left: Box::new(lhs_numeric),
67            op: Operator::GtEq,
68            right: Box::new(rhs_numeric),
69        }),
70        BinaryOperator::StrictEquals => {
71            // Use original values, not those converted to numeric
72            // Let DataFusion handle numeric casting
73            if is_null_literal(&lhs) {
74                Expr::IsNull(Box::new(rhs))
75            } else if is_null_literal(&rhs) {
76                Expr::IsNull(Box::new(lhs))
77            } else if is_numeric_datatype(&lhs_dtype) && is_numeric_datatype(&rhs_dtype)
78                || lhs_dtype == rhs_dtype
79            {
80                Expr::BinaryExpr(BinaryExpr {
81                    left: Box::new(lhs),
82                    op: Operator::Eq,
83                    right: Box::new(rhs),
84                })
85            } else {
86                // Types are not compatible
87                lit(false)
88            }
89        }
90        BinaryOperator::NotStrictEquals => {
91            if is_null_literal(&lhs) {
92                Expr::IsNotNull(Box::new(rhs))
93            } else if is_null_literal(&rhs) {
94                Expr::IsNotNull(Box::new(lhs))
95            } else if is_numeric_datatype(&lhs_dtype) && is_numeric_datatype(&rhs_dtype)
96                || lhs_dtype == rhs_dtype
97            {
98                Expr::BinaryExpr(BinaryExpr {
99                    left: Box::new(lhs),
100                    op: Operator::NotEq,
101                    right: Box::new(rhs),
102                })
103            } else {
104                // Types are not compatible
105                lit(false)
106            }
107        }
108        BinaryOperator::Plus => {
109            if is_string_datatype(&lhs_dtype) || is_string_datatype(&rhs_dtype) {
110                // If either argument is a string, then both are treated as string and
111                // plus is string concatenation
112                let lhs_string = to_string(lhs, schema)?;
113                let rhs_string = to_string(rhs, schema)?;
114                concat(vec![lhs_string, rhs_string])
115            } else {
116                // Both sides are non-strings, use regular numeric plus operation
117                // Use result of to_numeric to handle booleans
118                Expr::BinaryExpr(BinaryExpr {
119                    left: Box::new(lhs_numeric),
120                    op: Operator::Plus,
121                    right: Box::new(rhs_numeric),
122                })
123            }
124        }
125        BinaryOperator::Equals => {
126            if is_null_literal(&lhs) {
127                Expr::IsNull(Box::new(rhs))
128            } else if is_null_literal(&rhs) {
129                Expr::IsNull(Box::new(lhs))
130            } else if is_string_datatype(&lhs_dtype) && is_string_datatype(&rhs_dtype) {
131                // Regular equality on strings
132                Expr::BinaryExpr(BinaryExpr {
133                    left: Box::new(lhs),
134                    op: Operator::Eq,
135                    right: Box::new(rhs),
136                })
137            } else {
138                // Both sides converted to numbers
139                Expr::BinaryExpr(BinaryExpr {
140                    left: Box::new(lhs_numeric),
141                    op: Operator::Eq,
142                    right: Box::new(rhs_numeric),
143                })
144            }
145            // TODO: if both null, then equal. If one null, then not equal
146        }
147        BinaryOperator::NotEquals => {
148            if is_null_literal(&lhs) {
149                Expr::IsNotNull(Box::new(rhs))
150            } else if is_null_literal(&rhs) {
151                Expr::IsNotNull(Box::new(lhs))
152            } else if is_string_datatype(&lhs_dtype) && is_string_datatype(&rhs_dtype) {
153                // Regular inequality on strings
154                Expr::BinaryExpr(BinaryExpr {
155                    left: Box::new(lhs),
156                    op: Operator::NotEq,
157                    right: Box::new(rhs),
158                })
159            } else {
160                // Both sides converted to numbers
161                // Both sides converted to numbers
162                Expr::BinaryExpr(BinaryExpr {
163                    left: Box::new(lhs_numeric),
164                    op: Operator::NotEq,
165                    right: Box::new(rhs_numeric),
166                })
167            }
168            // TODO: if both null, then equal. If one null, then not equal
169        }
170        BinaryOperator::BitwiseAnd => bitwise_expr(lhs, Operator::BitwiseAnd, rhs, schema)?,
171        BinaryOperator::BitwiseOr => bitwise_expr(lhs, Operator::BitwiseOr, rhs, schema)?,
172        BinaryOperator::BitwiseXor => bitwise_expr(lhs, Operator::BitwiseXor, rhs, schema)?,
173        BinaryOperator::BitwiseShiftLeft => {
174            bitwise_expr(lhs, Operator::BitwiseShiftLeft, rhs, schema)?
175        }
176        BinaryOperator::BitwiseShiftRight => {
177            bitwise_expr(lhs, Operator::BitwiseShiftRight, rhs, schema)?
178        }
179    };
180
181    Ok(new_expr)
182}
183
184fn bitwise_expr(lhs: Expr, op: Operator, rhs: Expr, schema: &DFSchema) -> Result<Expr> {
185    // Vega treats null as zero for bitwise operations
186    let left_cast = cast_to(lhs, &DataType::Int64, schema)?;
187    let right_cast = cast_to(rhs, &DataType::Int64, schema)?;
188    let left = when(left_cast.clone().is_null(), lit(0)).otherwise(left_cast)?;
189    let right = when(right_cast.clone().is_null(), lit(0)).otherwise(right_cast)?;
190
191    Ok(Expr::BinaryExpr(BinaryExpr {
192        left: Box::new(left),
193        op,
194        right: Box::new(right),
195    }))
196}