1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::ArrayRef;
6use vortex_array::stats::Stat;
7use vortex_dtype::{DType, FieldPath};
8use vortex_error::{VortexResult, vortex_err};
9
10use crate::{
11 AccessPath, AnalysisExpr, ExprRef, Identifier, Scope, ScopeDType, StatsCatalog, VortexExpr,
12};
13
14#[derive(Debug, PartialEq, Eq, Hash)]
15pub struct Var {
16 var: Identifier,
17}
18
19impl Var {
22 pub fn new_expr(var: Identifier) -> ExprRef {
23 Arc::new(Self { var })
24 }
25
26 pub fn var(&self) -> &Identifier {
27 &self.var
28 }
29}
30
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::{Kind, Var as ProtoVar};

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Var, root};

    /// Deserializer for the `identity` expression kind, which decodes to [`root()`].
    pub(crate) struct IdentitySerde;

    impl Id for IdentitySerde {
        fn id(&self) -> &'static str {
            "identity"
        }
    }

    impl ExprDeserialize for IdentitySerde {
        /// Decodes a `Kind::Identity` into the root variable expression; `_children` is ignored.
        ///
        /// # Errors
        ///
        /// Fails if `kind` is any variant other than `Kind::Identity`.
        fn deserialize(&self, kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            match kind {
                Kind::Identity(..) => Ok(root()),
                _ => vortex_bail!("wrong kind {:?}, wanted identity", kind),
            }
        }
    }

    /// Serializer/deserializer for [`Var`] expressions (kind id `var`).
    pub(crate) struct VarSerde;

    impl Id for VarSerde {
        fn id(&self) -> &'static str {
            "var"
        }
    }

    impl ExprDeserialize for VarSerde {
        /// Decodes a `Kind::Var` into a [`Var`] expression; `_children` is ignored.
        ///
        /// An empty variable name maps to `Identifier::Identity`; any other name is parsed
        /// as an `Identifier`.
        ///
        /// # Errors
        ///
        /// Fails if `kind` is not `Kind::Var`, or if the non-empty name does not parse.
        fn deserialize(&self, kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let op = match kind {
                Kind::Var(op) => op,
                _ => vortex_bail!("wrong kind {:?}, wanted var", kind),
            };

            // The empty string is special-cased because it does not parse as an Identifier.
            let ident = if op.var.is_empty() {
                crate::Identifier::Identity
            } else {
                op.var.parse()?
            };
            Ok(Var::new_expr(ident))
        }
    }

    impl ExprSerializable for Var {
        fn id(&self) -> &'static str {
            VarSerde.id()
        }

        /// Encodes this variable as `Kind::Var`, storing the identifier's `to_string()` form.
        ///
        /// NOTE(review): round-tripping `Identifier::Identity` relies on its string form being
        /// empty, which is what `VarSerde::deserialize` maps back to `Identity` — confirm.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            let var = self.var.to_string();
            Ok(Kind::Var(ProtoVar { var }))
        }
    }
}
91
92impl Display for Var {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "${}", self.var)
95 }
96}
97
98impl AnalysisExpr for Var {
99 fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
100 catalog.stats_ref(&self.field_path()?, Stat::Max)
101 }
102
103 fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
104 catalog.stats_ref(&self.field_path()?, Stat::Min)
105 }
106
107 fn field_path(&self) -> Option<AccessPath> {
108 Some(AccessPath::new(FieldPath::root(), self.var.clone()))
109 }
110}
111
112impl VortexExpr for Var {
113 fn as_any(&self) -> &dyn Any {
114 self
115 }
116
117 fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
118 ctx.array(&self.var)
119 .cloned()
120 .ok_or_else(|| vortex_err!("cannot find '{}' in arrays scope", self.var))
121 }
122
123 fn children(&self) -> Vec<&ExprRef> {
124 vec![]
125 }
126
127 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
128 assert_eq!(children.len(), 0);
129 Var::new_expr(self.var.clone())
130 }
131
132 fn return_dtype(&self, dt_ctx: &ScopeDType) -> VortexResult<DType> {
133 dt_ctx
134 .dtype(&self.var)
135 .cloned()
136 .ok_or_else(|| vortex_err!("cannot find '{}' in dtype scope", self.var))
137 }
138}
139
140pub fn var(ident: Identifier) -> ExprRef {
141 Var::new_expr(ident)
142}
143
144pub fn root() -> ExprRef {
147 Var::new_expr(Identifier::Identity)
148}
149
150pub fn is_root(expr: &ExprRef) -> bool {
151 expr.as_any()
152 .downcast_ref::<Var>()
153 .is_some_and(|v| v.var().is_identity())
154}
155
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::{Identifier, Scope, eq, var};

    /// Evaluates `$ == $row` in a scope that binds both the root array and a `row` array,
    /// checking that the comparison is element-wise.
    #[test]
    fn test_two_vars() {
        let root_array =
            PrimitiveArray::new(buffer![5, 4, 3, 2, 1, 0], Validity::AllValid).to_array();
        let row_array = PrimitiveArray::from_iter(1..=6).to_array();

        let lhs = var(Identifier::Identity);
        let rhs = var("row".parse().unwrap());
        let expr = eq(lhs, rhs);

        let mask = expr
            .evaluate(&Scope::new(root_array).with_array("row".parse().unwrap(), row_array))
            .unwrap();
        let bits = mask.to_bool().unwrap().boolean_buffer().iter().collect_vec();

        // [5, 4, 3, 2, 1, 0] vs [1, 2, 3, 4, 5, 6]: only index 2 matches (3 == 3).
        assert_eq!(bits, vec![false, false, true, false, false, false])
    }

    /// The empty string must be rejected when parsed as an identifier.
    #[test]
    fn test_empty_string_ident_not_allowed() {
        let parsed = Identifier::from_str("");
        assert!(parsed.is_err());
    }
}