vegafusion_core/expression/ast/expression.rs

1use crate::error::{Result, VegaFusionError};
2use crate::expression::column_usage::{
3    DatasetsColumnUsage, GetDatasetsColumnUsage, VlSelectionFields,
4};
5use crate::expression::visitors::{
6    CheckSupportedExprVisitor, ClearSpansVisitor, DatasetsColumnUsageVisitor, ExpressionVisitor,
7    GetInputVariablesVisitor, ImplicitVariablesExprVisitor, MutExpressionVisitor,
8    UpdateVariablesExprVisitor,
9};
10use crate::proto::gen::expression::expression::Expr;
11use crate::proto::gen::expression::{
12    literal, ArrayExpression, BinaryExpression, CallExpression, ConditionalExpression, Expression,
13    Identifier, Literal, LogicalExpression, MemberExpression, ObjectExpression, Span,
14    UnaryExpression,
15};
16use crate::proto::gen::tasks::Variable;
17use crate::task_graph::graph::ScopedVariable;
18use crate::task_graph::scope::TaskScope;
19use crate::task_graph::task::InputVariable;
20use itertools::sorted;
21use std::fmt::{Display, Formatter};
22use std::ops::Deref;
23
/// Common interface implemented by every AST node type.
///
/// Every implementor is also [`Display`], so any node can be formatted.
pub trait ExpressionTrait: Display {
    /// Returns this node's `(left, right)` binding power.
    ///
    /// Whenever operator associativity is ambiguous, the sub-expression with the
    /// lower binding power is the one that has to be wrapped in parentheses.
    ///
    /// Unless overridden, both sides are `1000.0` (presumably high enough that such
    /// nodes never need parenthesizing — TODO confirm against the operator impls).
    fn binding_power(&self) -> (f64, f64) {
        let default_power = 1000.0_f64;
        (default_power, default_power)
    }
}
33
34impl Deref for Expression {
35    type Target = dyn ExpressionTrait;
36
37    fn deref(&self) -> &Self::Target {
38        match self.expr.as_ref().unwrap() {
39            Expr::Identifier(expr) => expr,
40            Expr::Literal(expr) => expr,
41            Expr::Binary(expr) => expr.as_ref(),
42            Expr::Logical(expr) => expr.as_ref(),
43            Expr::Unary(expr) => expr.as_ref(),
44            Expr::Conditional(expr) => expr.as_ref(),
45            Expr::Call(expr) => expr,
46            Expr::Array(expr) => expr,
47            Expr::Object(expr) => expr,
48            Expr::Member(expr) => expr.as_ref(),
49        }
50    }
51}
52
53impl Display for Expression {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        let expr = self.deref();
56        write!(f, "{expr}")
57    }
58}
59
60impl Expression {
61    pub fn new(expr: Expr, span: Option<Span>) -> Self {
62        Self {
63            expr: Some(expr),
64            span,
65        }
66    }
67
68    pub fn clear_spans(&mut self) {
69        let mut visitor = ClearSpansVisitor::new();
70        self.walk_mut(&mut visitor);
71    }
72
73    pub fn input_vars(&self) -> Vec<InputVariable> {
74        let mut visitor = GetInputVariablesVisitor::new();
75        self.walk(&mut visitor);
76
77        sorted(visitor.input_variables).collect()
78    }
79
80    pub fn update_vars(&self) -> Vec<Variable> {
81        let mut visitor = UpdateVariablesExprVisitor::new();
82        self.walk(&mut visitor);
83
84        sorted(visitor.update_variables).collect()
85    }
86
87    pub fn implicit_vars(&self) -> Vec<String> {
88        let mut visitor = ImplicitVariablesExprVisitor::new();
89        self.walk(&mut visitor);
90        sorted(visitor.implicit_vars).collect()
91    }
92
93    pub fn is_supported(&self) -> bool {
94        let mut visitor = CheckSupportedExprVisitor::new();
95        self.walk(&mut visitor);
96        visitor.supported
97    }
98
99    /// Walk visitor through the expression tree in a DFS traversal
100    pub fn walk(&self, visitor: &mut dyn ExpressionVisitor) {
101        match self.expr.as_ref().unwrap() {
102            Expr::Binary(node) => {
103                node.left().walk(visitor);
104                node.right().walk(visitor);
105                visitor.visit_binary(node);
106            }
107            Expr::Logical(node) => {
108                node.left().walk(visitor);
109                node.right().walk(visitor);
110                visitor.visit_logical(node);
111            }
112            Expr::Unary(node) => {
113                node.argument().walk(visitor);
114                visitor.visit_unary(node);
115            }
116            Expr::Conditional(node) => {
117                node.test().walk(visitor);
118                node.consequent().walk(visitor);
119                node.alternate().walk(visitor);
120                visitor.visit_conditional(node);
121            }
122            Expr::Literal(node) => {
123                visitor.visit_literal(node);
124            }
125            Expr::Identifier(node) => {
126                visitor.visit_identifier(node);
127            }
128            Expr::Call(node) => {
129                let callee_id = Identifier {
130                    name: node.callee.clone(),
131                };
132                visitor.visit_called_identifier(&callee_id, &node.arguments);
133                for arg in &node.arguments {
134                    arg.walk(visitor);
135                }
136                visitor.visit_call(node);
137            }
138            Expr::Array(node) => {
139                for el in &node.elements {
140                    el.walk(visitor);
141                }
142                visitor.visit_array(node);
143            }
144            Expr::Object(node) => {
145                for prop in &node.properties {
146                    visitor.visit_object_key(prop.key.as_ref().unwrap());
147                    prop.value.as_ref().unwrap().walk(visitor);
148                }
149                visitor.visit_object(node);
150            }
151            Expr::Member(node) => {
152                node.object.as_ref().unwrap().walk(visitor);
153                let prop_expr = node.property.as_ref().unwrap().expr.as_ref().unwrap();
154                if let Expr::Identifier(identifier) = prop_expr {
155                    visitor.visit_static_member_identifier(identifier);
156                } else {
157                    node.property.as_ref().unwrap().walk(visitor);
158                }
159                visitor.visit_member(node);
160            }
161        }
162        visitor.visit_expression(self);
163    }
164
165    pub fn walk_mut(&mut self, visitor: &mut dyn MutExpressionVisitor) {
166        match self.expr.as_mut().unwrap() {
167            Expr::Binary(node) => {
168                node.left.as_mut().unwrap().walk_mut(visitor);
169                node.right.as_mut().unwrap().walk_mut(visitor);
170                visitor.visit_binary(node);
171            }
172            Expr::Logical(node) => {
173                node.left.as_mut().unwrap().walk_mut(visitor);
174                node.right.as_mut().unwrap().walk_mut(visitor);
175                visitor.visit_logical(node);
176            }
177            Expr::Unary(node) => {
178                node.argument.as_mut().unwrap().walk_mut(visitor);
179                visitor.visit_unary(node);
180            }
181            Expr::Conditional(node) => {
182                node.test.as_mut().unwrap().walk_mut(visitor);
183                node.consequent.as_mut().unwrap().walk_mut(visitor);
184                node.alternate.as_mut().unwrap().walk_mut(visitor);
185                visitor.visit_conditional(node);
186            }
187            Expr::Literal(node) => {
188                visitor.visit_literal(node);
189            }
190            Expr::Identifier(node) => {
191                visitor.visit_identifier(node);
192            }
193            Expr::Call(node) => {
194                let mut callee_id = Identifier {
195                    name: node.callee.clone(),
196                };
197                visitor.visit_called_identifier(&mut callee_id, &mut node.arguments);
198                for arg in &mut node.arguments {
199                    arg.walk_mut(visitor);
200                }
201                visitor.visit_call(node);
202            }
203            Expr::Array(node) => {
204                for el in &mut node.elements {
205                    el.walk_mut(visitor);
206                }
207                visitor.visit_array(node);
208            }
209            Expr::Object(node) => {
210                for prop in &mut node.properties {
211                    visitor.visit_object_key(prop.key.as_mut().unwrap());
212                    prop.value.as_mut().unwrap().walk_mut(visitor);
213                }
214                visitor.visit_object(node);
215            }
216            Expr::Member(node) => {
217                node.object.as_mut().unwrap().walk_mut(visitor);
218                let prop_expr = node.property.as_mut().unwrap().expr.as_mut().unwrap();
219                if let Expr::Identifier(identifier) = prop_expr {
220                    visitor.visit_static_member_identifier(identifier);
221                } else {
222                    node.property.as_mut().unwrap().walk_mut(visitor);
223                }
224                visitor.visit_member(node);
225            }
226        }
227        visitor.visit_expression(self);
228    }
229
230    pub fn as_identifier(&self) -> Result<&Identifier> {
231        match &self.expr {
232            Some(Expr::Identifier(identifier)) => Ok(identifier),
233            _ => Err(VegaFusionError::internal("Expression is not an identifier")),
234        }
235    }
236
237    pub fn as_literal(&self) -> Result<&Literal> {
238        match &self.expr {
239            Some(Expr::Literal(value)) => Ok(value),
240            _ => Err(VegaFusionError::internal("Expression is not a Literal")),
241        }
242    }
243
244    pub fn expr(&self) -> &Expr {
245        self.expr.as_ref().unwrap()
246    }
247}
248
249// Expression from literal
250impl<V: Into<literal::Value>> From<V> for Expression {
251    fn from(v: V) -> Self {
252        Self {
253            expr: Some(Expr::from(v)),
254            span: None,
255        }
256    }
257}
258
259// Expr conversions
260impl From<Literal> for Expr {
261    fn from(v: Literal) -> Self {
262        Self::Literal(v)
263    }
264}
265
266impl From<Identifier> for Expr {
267    fn from(v: Identifier) -> Self {
268        Self::Identifier(v)
269    }
270}
271
272impl From<UnaryExpression> for Expr {
273    fn from(v: UnaryExpression) -> Self {
274        Self::Unary(Box::new(v))
275    }
276}
277
278impl From<BinaryExpression> for Expr {
279    fn from(v: BinaryExpression) -> Self {
280        Self::Binary(Box::new(v))
281    }
282}
283
284impl From<LogicalExpression> for Expr {
285    fn from(v: LogicalExpression) -> Self {
286        Self::Logical(Box::new(v))
287    }
288}
289
290impl From<CallExpression> for Expr {
291    fn from(v: CallExpression) -> Self {
292        Self::Call(v)
293    }
294}
295
296impl From<MemberExpression> for Expr {
297    fn from(v: MemberExpression) -> Self {
298        Self::Member(Box::new(v))
299    }
300}
301
302impl From<ConditionalExpression> for Expr {
303    fn from(v: ConditionalExpression) -> Self {
304        Self::Conditional(Box::new(v))
305    }
306}
307
308impl From<ArrayExpression> for Expr {
309    fn from(v: ArrayExpression) -> Self {
310        Self::Array(v)
311    }
312}
313
314impl From<ObjectExpression> for Expr {
315    fn from(v: ObjectExpression) -> Self {
316        Self::Object(v)
317    }
318}
319
320impl<V: Into<literal::Value>> From<V> for Expr {
321    fn from(v: V) -> Self {
322        let v = v.into();
323        let repr = v.to_string();
324        Self::Literal(Literal::new(v, &repr))
325    }
326}
327
328impl GetDatasetsColumnUsage for Expression {
329    fn datasets_column_usage(
330        &self,
331        datum_var: &Option<ScopedVariable>,
332        usage_scope: &[u32],
333        task_scope: &TaskScope,
334        vl_selection_fields: &VlSelectionFields,
335    ) -> DatasetsColumnUsage {
336        let mut visitor = DatasetsColumnUsageVisitor::new(
337            datum_var,
338            usage_scope,
339            task_scope,
340            vl_selection_fields,
341        );
342        self.walk(&mut visitor);
343        visitor.dataset_column_usage
344    }
345}