vortex_expr/exprs/
get_item.rs1use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6
7use vortex_array::stats::Stat;
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, ToCanonical};
9use vortex_dtype::{DType, FieldName, FieldPath};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{
15 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, root,
16 vtable,
17};
18
19vtable!(GetItem);
20
/// Expression that selects the field named `field` from the value produced by `child`.
// `Hash` is derived while `PartialEq` is hand-written further down in this file; the manual
// `eq` compares the same two fields (`field`, `child`) that the derive hashes.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct GetItemExpr {
    /// Name of the field to extract.
    field: FieldName,
    /// Expression whose result the field is read from.
    child: ExprRef,
}
27
28impl PartialEq for GetItemExpr {
29 fn eq(&self, other: &Self) -> bool {
30 self.field == other.field && self.child.eq(&other.child)
31 }
32}
33
/// Encoding marker type for [`GetItemExpr`]; used as the `Encoding` associated type of the
/// `GetItem` vtable implementation in this file.
pub struct GetItemExprEncoding;
35
36impl VTable for GetItemVTable {
37 type Expr = GetItemExpr;
38 type Encoding = GetItemExprEncoding;
39 type Metadata = ProstMetadata<pb::GetItemOpts>;
40
41 fn id(_encoding: &Self::Encoding) -> ExprId {
42 ExprId::new_ref("get_item")
43 }
44
45 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46 ExprEncodingRef::new_ref(GetItemExprEncoding.as_ref())
47 }
48
49 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50 Some(ProstMetadata(pb::GetItemOpts {
51 path: expr.field.to_string(),
52 }))
53 }
54
55 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56 vec![&expr.child]
57 }
58
59 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60 Ok(GetItemExpr {
61 field: expr.field.clone(),
62 child: children[0].clone(),
63 })
64 }
65
66 fn build(
67 _encoding: &Self::Encoding,
68 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
69 children: Vec<ExprRef>,
70 ) -> VortexResult<Self::Expr> {
71 if children.len() != 1 {
72 vortex_bail!(
73 "GetItem expression must have exactly 1 child, got {}",
74 children.len()
75 );
76 }
77
78 let field = FieldName::from(metadata.path.clone());
79 Ok(GetItemExpr {
80 field,
81 child: children[0].clone(),
82 })
83 }
84
85 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
86 expr.child
87 .unchecked_evaluate(scope)?
88 .to_struct()?
89 .field_by_name(expr.field())
90 .cloned()
91 }
92
93 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
94 let input = expr.child.return_dtype(scope)?;
95 input
96 .as_struct_fields_opt()
97 .and_then(|st| st.field(expr.field()))
98 .ok_or_else(|| {
99 vortex_err!(
100 "Couldn't find the {} field in the input scope",
101 expr.field()
102 )
103 })
104 }
105}
106
107impl GetItemExpr {
108 pub fn new(field: impl Into<FieldName>, child: ExprRef) -> Self {
109 Self {
110 field: field.into(),
111 child,
112 }
113 }
114
115 pub fn new_expr(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
116 Self::new(field, child).into_expr()
117 }
118
119 pub fn field(&self) -> &FieldName {
120 &self.field
121 }
122
123 pub fn child(&self) -> &ExprRef {
124 &self.child
125 }
126
127 pub fn is(expr: &ExprRef) -> bool {
128 expr.is::<GetItemVTable>()
129 }
130}
131
132pub fn col(field: impl Into<FieldName>) -> ExprRef {
141 GetItemExpr::new(field, root()).into_expr()
142}
143
144pub fn get_item(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
153 GetItemExpr::new(field, child).into_expr()
154}
155
156impl DisplayAs for GetItemExpr {
157 fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
158 match df {
159 DisplayFormat::Compact => {
160 write!(f, "{}.{}", self.child, &self.field)
161 }
162 DisplayFormat::Tree => {
163 write!(f, "GetItem({})", self.field)
164 }
165 }
166 }
167}
168impl AnalysisExpr for GetItemExpr {
169 fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
170 catalog.stats_ref(&self.field_path()?, Stat::Max)
171 }
172
173 fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
174 catalog.stats_ref(&self.field_path()?, Stat::Min)
175 }
176
177 fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
178 catalog.stats_ref(&self.field_path()?, Stat::NaNCount)
179 }
180
181 fn field_path(&self) -> Option<FieldPath> {
182 self.child()
183 .field_path()
184 .map(|fp| fp.push(self.field.clone()))
185 }
186}
187
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::PType::I32;

    use crate::get_item::get_item;
    use crate::{Scope, root};

    /// Fixture: a struct array with an `i32` column `a` and an `i64` column `b`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Selecting the existing column `a` succeeds and yields the `I32` dtype.
    #[test]
    pub fn get_item_by_name() {
        let fixture = test_array();
        let expr = get_item("a", root());
        let scope = Scope::new(fixture.to_array());
        let column = expr.evaluate(&scope).unwrap();
        assert_eq!(column.dtype(), &DType::from(I32));
    }

    /// Selecting the missing column `c` is reported as an error.
    #[test]
    pub fn get_item_by_name_none() {
        let fixture = test_array();
        let expr = get_item("c", root());
        let scope = Scope::new(fixture.to_array());
        assert!(expr.evaluate(&scope).is_err());
    }
}