1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::ArrayRef;
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{AnalysisExpr, ExprRef, Identifier, Scope, ScopeDType, VortexExpr};
10
/// A `let` binding expression: `bind` is evaluated first, its result is bound to `var`,
/// and then `expr` is evaluated in a scope that includes that binding.
// `Hash` is derived while `PartialEq` is hand-written further down in this file. The manual
// `eq` compares the same three fields (`var`, `bind`, `expr`). This is intended to keep the two
// consistent, assuming `ExprRef`'s `Hash` agrees with its `eq`.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Eq, Hash)]
pub struct Let {
    /// Name under which the value produced by `bind` is visible while evaluating `expr`.
    var: Identifier,
    /// Expression evaluated in the enclosing scope to produce the bound value.
    bind: ExprRef,
    /// Body expression, evaluated with `var` in scope.
    expr: ExprRef,
}
20
21impl Let {
22 pub fn new_expr(var: Identifier, bind: ExprRef, expr: ExprRef) -> ExprRef {
23 Arc::new(Self { var, bind, expr })
24 }
25}
26
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::let_::Let;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde handle for [`Let`] expressions, registered under the id `"let"`.
    pub(crate) struct LetSerde;

    impl Id for LetSerde {
        fn id(&self) -> &'static str {
            "let"
        }
    }

    impl ExprDeserialize for LetSerde {
        /// Rebuilds a [`Let`] from its protobuf kind and its children.
        ///
        /// `children` is interpreted as `[bind, expr]`, which matches the order returned by
        /// `Let::children`.
        ///
        /// # Errors
        ///
        /// Fails in three cases: `kind` is not `Kind::Let`, the stored variable name does not
        /// parse as an identifier, or `children` does not contain exactly two expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::Let(op) = kind else {
                vortex_bail!("wrong kind {:?}, wanted let", kind)
            };

            // Malformed input must surface as an error rather than an out-of-bounds panic.
            let n_children = children.len();
            let Ok([bind, expr]) = <[ExprRef; 2]>::try_from(children) else {
                vortex_bail!(
                    "let expects exactly 2 children (bind, expr), got {}",
                    n_children
                )
            };

            Ok(Let::new_expr(op.var.parse()?, bind, expr))
        }
    }

    impl ExprSerializable for Let {
        fn id(&self) -> &'static str {
            LetSerde.id()
        }

        /// Encodes this expression as `Kind::Let`, which stores only the bound variable name.
        /// `LetSerde::deserialize` receives `bind` and `expr` separately, as its `children`.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            // Must be `Kind::Let`: `LetSerde::deserialize` rejects any other kind, so emitting
            // a different variant (e.g. `Identity`) makes the expression impossible to read back.
            Ok(Kind::Let(kind::Let {
                var: self.var.to_string(),
            }))
        }
    }
}
68
69impl Display for Let {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(f, "let {} = {} in {}", self.var, self.bind, self.expr)
72 }
73}
74
// Empty impl: `Let` relies on the default implementation of every `AnalysisExpr` method.
impl AnalysisExpr for Let {}
76
77impl VortexExpr for Let {
78 fn as_any(&self) -> &dyn Any {
79 self
80 }
81
82 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
83 let v = self.bind.unchecked_evaluate(scope)?;
84 let ctx_p = scope.copy_with_value(self.var.clone(), v);
85 self.expr.unchecked_evaluate(&ctx_p)
86 }
87
88 fn children(&self) -> Vec<&ExprRef> {
89 vec![&self.bind, &self.expr]
90 }
91
92 fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
93 assert_eq!(children.len(), 2);
94 let expr = children.remove(1);
95 let bind = children.remove(0);
96 Let::new_expr(self.var.clone(), bind, expr)
97 }
98
99 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
100 let v = self.bind.return_dtype(scope)?;
101 let ctx_p = scope.copy_with_value(self.var.clone(), v);
102 self.expr.return_dtype(&ctx_p)
103 }
104}
105
106impl PartialEq for Let {
107 fn eq(&self, other: &Let) -> bool {
108 self.var == other.var && self.bind.eq(&other.bind) && self.expr.eq(&other.expr)
109 }
110}
111
112pub fn let_(ident: Identifier, bind: ExprRef, expr: ExprRef) -> ExprRef {
113 Let::new_expr(ident, bind, expr)
114}
115
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::{Scope, eq, get_item_scope, let_, var};

    /// Nested `let`s bind `x` to column `a1` and `y` to column `a2`. The body compares them
    /// element-wise, so only the position where both columns hold `3` is `true`.
    #[test]
    fn test_two_vars() {
        let descending =
            PrimitiveArray::new(buffer![5, 4, 3, 2, 1, 0], Validity::AllValid).to_array();
        let ascending = PrimitiveArray::from_iter(1..=6).to_array();
        let input = StructArray::from_fields(&[("a1", descending), ("a2", ascending)])
            .unwrap()
            .to_array();

        let body = eq(var("x".parse().unwrap()), var("y".parse().unwrap()));
        let inner = let_("y".parse().unwrap(), get_item_scope("a2"), body);
        let expr = let_("x".parse().unwrap(), get_item_scope("a1"), inner);

        let mask = expr.evaluate(&Scope::new(input)).unwrap().to_bool().unwrap();
        let actual = mask.boolean_buffer().iter().collect_vec();

        assert_eq!(actual, vec![false, false, true, false, false, false]);
    }
}
149}