vortex_expr/
let_.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::ArrayRef;
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{AnalysisExpr, ExprRef, Identifier, Scope, ScopeDType, VortexExpr};
10
11#[allow(clippy::derived_hash_with_manual_eq)]
12#[derive(Debug, Eq, Hash)]
13/// Let expressions are of the form `let var = bind in expr`,
14/// see `Scope`.
15pub struct Let {
16    var: Identifier,
17    bind: ExprRef,
18    expr: ExprRef,
19}
20
21impl Let {
22    pub fn new_expr(var: Identifier, bind: ExprRef, expr: ExprRef) -> ExprRef {
23        Arc::new(Self { var, bind, expr })
24    }
25}
26
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::let_::Let;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde handle for [`Let`] expressions, registered under the id `"let"`.
    pub(crate) struct LetSerde;

    impl Id for LetSerde {
        fn id(&self) -> &'static str {
            "let"
        }
    }

    impl ExprDeserialize for LetSerde {
        /// Rebuilds a `Let` from a `Kind::Let` message and its children `[bind, expr]`.
        ///
        /// # Errors
        /// Fails if `kind` is not `Kind::Let`, if the stored variable name does not parse as an
        /// identifier, or if there are not exactly two children.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::Let(op) = kind else {
                vortex_bail!("wrong kind {:?}, wanted let", kind)
            };

            // Serialized input may be malformed: report a wrong child count as an error instead
            // of panicking on out-of-bounds indexing.
            let [bind, expr]: [ExprRef; 2] = match children.try_into() {
                Ok(pair) => pair,
                Err(children) => {
                    vortex_bail!(
                        "wrong number of children {}, wanted 2 (bind, expr) for let",
                        children.len()
                    )
                }
            };

            Ok(Let::new_expr(op.var.parse()?, bind, expr))
        }
    }

    impl ExprSerializable for Let {
        fn id(&self) -> &'static str {
            LetSerde.id()
        }

        /// Encodes this node as `Kind::Let` carrying the variable name, which is what
        /// `LetSerde::deserialize` expects. The subexpressions are not part of the kind;
        /// `deserialize` expects them as children in `[bind, expr]` order, the same order
        /// `VortexExpr::children` returns them in.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            // Previously this emitted `Kind::Identity`, which `deserialize` rejects and which
            // dropped the variable name, so `Let` expressions could not round-trip.
            Ok(Kind::Let(kind::Let {
                var: self.var.to_string(),
            }))
        }
    }
}
68
69impl Display for Let {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "let {} = {} in {}", self.var, self.bind, self.expr)
72    }
73}
74
// No methods are overridden: every `AnalysisExpr` method falls back to the trait's default
// implementation.
// NOTE(review): `Let` introduces a new variable binding; confirm the defaults are correct for it.
impl AnalysisExpr for Let {}
76
77impl VortexExpr for Let {
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
83        let v = self.bind.unchecked_evaluate(scope)?;
84        let ctx_p = scope.copy_with_value(self.var.clone(), v);
85        self.expr.unchecked_evaluate(&ctx_p)
86    }
87
88    fn children(&self) -> Vec<&ExprRef> {
89        vec![&self.bind, &self.expr]
90    }
91
92    fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
93        assert_eq!(children.len(), 2);
94        let expr = children.remove(1);
95        let bind = children.remove(0);
96        Let::new_expr(self.var.clone(), bind, expr)
97    }
98
99    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
100        let v = self.bind.return_dtype(scope)?;
101        let ctx_p = scope.copy_with_value(self.var.clone(), v);
102        self.expr.return_dtype(&ctx_p)
103    }
104}
105
106impl PartialEq for Let {
107    fn eq(&self, other: &Let) -> bool {
108        self.var == other.var && self.bind.eq(&other.bind) && self.expr.eq(&other.expr)
109    }
110}
111
112pub fn let_(ident: Identifier, bind: ExprRef, expr: ExprRef) -> ExprRef {
113    Let::new_expr(ident, bind, expr)
114}
115
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::{Scope, eq, get_item_scope, let_, var};

    /// Binds field `a1` to `x` and field `a2` to `y` through nested `let`s, then compares
    /// `x == y` element-wise. The two columns only agree at index 2.
    #[test]
    fn test_two_vars() {
        let descending =
            PrimitiveArray::new(buffer![5, 4, 3, 2, 1, 0], Validity::AllValid).to_array();
        let ascending = PrimitiveArray::from_iter(1..=6).to_array();
        let input = StructArray::from_fields(&[("a1", descending), ("a2", ascending)])
            .unwrap()
            .to_array();

        let comparison = eq(var("x".parse().unwrap()), var("y".parse().unwrap()));
        let inner = let_("y".parse().unwrap(), get_item_scope("a2"), comparison);
        let expr = let_("x".parse().unwrap(), get_item_scope("a1"), inner);

        let mask = expr
            .evaluate(&Scope::new(input))
            .unwrap()
            .to_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect_vec();

        assert_eq!(mask, vec![false, false, true, false, false, false]);
    }
}