1use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6use std::ops::Not;
7
8use vortex_array::compute::mask;
9use vortex_array::stats::Stat;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, ToCanonical};
11use vortex_dtype::{DType, FieldName, FieldPath, Nullability};
12use vortex_error::{VortexResult, vortex_bail, vortex_err};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, root,
18 vtable,
19};
20
// Registers the `GetItem` expression with the expression vtable machinery.
// NOTE(review): `vtable!` is defined elsewhere; presumably it declares the `GetItemVTable`
// type that the `VTable` impl and `GetItemExpr::is` in this file refer to — confirm.
vtable!(GetItem);
22
/// Expression that extracts the field named `field` from the struct produced by `child`.
// `PartialEq` is written by hand (see the impl below) while `Hash` is derived; clippy's
// `derived_hash_with_manual_eq` lint flags that pairing, hence the allow. The manual `eq`
// compares exactly the two fields the derived `Hash` hashes, so `a == b => hash(a) == hash(b)`
// holds. Keep the two in sync if fields are added.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct GetItemExpr {
    /// Name of the struct field to select.
    field: FieldName,
    /// Expression whose result is the struct to read `field` from.
    child: ExprRef,
}
29
30impl PartialEq for GetItemExpr {
31 fn eq(&self, other: &Self) -> bool {
32 self.field == other.field && self.child.eq(&other.child)
33 }
34}
35
/// Encoding marker for [`GetItemExpr`]; this is the `Encoding` type of `GetItemVTable`.
pub struct GetItemExprEncoding;
37
38impl VTable for GetItemVTable {
39 type Expr = GetItemExpr;
40 type Encoding = GetItemExprEncoding;
41 type Metadata = ProstMetadata<pb::GetItemOpts>;
42
43 fn id(_encoding: &Self::Encoding) -> ExprId {
44 ExprId::new_ref("get_item")
45 }
46
47 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
48 ExprEncodingRef::new_ref(GetItemExprEncoding.as_ref())
49 }
50
51 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
52 Some(ProstMetadata(pb::GetItemOpts {
53 path: expr.field.to_string(),
54 }))
55 }
56
57 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
58 vec![&expr.child]
59 }
60
61 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
62 Ok(GetItemExpr {
63 field: expr.field.clone(),
64 child: children[0].clone(),
65 })
66 }
67
68 fn build(
69 _encoding: &Self::Encoding,
70 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
71 children: Vec<ExprRef>,
72 ) -> VortexResult<Self::Expr> {
73 if children.len() != 1 {
74 vortex_bail!(
75 "GetItem expression must have exactly 1 child, got {}",
76 children.len()
77 );
78 }
79
80 let field = FieldName::from(metadata.path.clone());
81 Ok(GetItemExpr {
82 field,
83 child: children[0].clone(),
84 })
85 }
86
87 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
88 let input = expr.child.unchecked_evaluate(scope)?.to_struct();
89 let field = input.field_by_name(expr.field()).cloned()?;
90
91 match input.dtype().nullability() {
92 Nullability::NonNullable => Ok(field),
93 Nullability::Nullable => mask(&field, &input.validity_mask().not()),
94 }
95 }
96
97 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
98 let input = expr.child.return_dtype(scope)?;
99 input
100 .as_struct_fields_opt()
101 .and_then(|st| st.field(expr.field()))
102 .map(|f| f.union_nullability(input.nullability()))
103 .ok_or_else(|| {
104 vortex_err!(
105 "Couldn't find the {} field in the input scope",
106 expr.field()
107 )
108 })
109 }
110}
111
112impl GetItemExpr {
113 pub fn new(field: impl Into<FieldName>, child: ExprRef) -> Self {
114 Self {
115 field: field.into(),
116 child,
117 }
118 }
119
120 pub fn new_expr(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
121 Self::new(field, child).into_expr()
122 }
123
124 pub fn field(&self) -> &FieldName {
125 &self.field
126 }
127
128 pub fn child(&self) -> &ExprRef {
129 &self.child
130 }
131
132 pub fn is(expr: &ExprRef) -> bool {
133 expr.is::<GetItemVTable>()
134 }
135}
136
137pub fn col(field: impl Into<FieldName>) -> ExprRef {
146 GetItemExpr::new(field, root()).into_expr()
147}
148
149pub fn get_item(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
158 GetItemExpr::new(field, child).into_expr()
159}
160
161impl DisplayAs for GetItemExpr {
162 fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
163 match df {
164 DisplayFormat::Compact => {
165 write!(f, "{}.{}", self.child, &self.field)
166 }
167 DisplayFormat::Tree => {
168 write!(f, "GetItem({})", self.field)
169 }
170 }
171 }
172}
173impl AnalysisExpr for GetItemExpr {
174 fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
175 catalog.stats_ref(&self.field_path()?, Stat::Max)
176 }
177
178 fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
179 catalog.stats_ref(&self.field_path()?, Stat::Min)
180 }
181
182 fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
183 catalog.stats_ref(&self.field_path()?, Stat::NaNCount)
184 }
185
186 fn field_path(&self) -> Option<FieldPath> {
187 self.child()
188 .field_path()
189 .map(|fp| fp.push(self.field.clone()))
190 }
191}
192
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, FieldNames, Nullability};
    use vortex_scalar::Scalar;

    use crate::get_item::get_item;
    use crate::{Scope, root};

    /// Three-row struct with an `i32` column `a` and an `i64` column `b`.
    fn test_array() -> StructArray {
        let a = buffer![0i32, 1, 2].into_array();
        let b = buffer![4i64, 5, 6].into_array();
        StructArray::from_fields(&[("a", a), ("b", b)]).unwrap()
    }

    /// Selecting an existing column yields an array with that column's dtype.
    #[test]
    fn get_item_by_name() {
        let scope = Scope::new(test_array().to_array());
        let expr = get_item("a", root());
        let column = expr.evaluate(&scope).unwrap();
        assert_eq!(column.dtype(), &DType::from(I32))
    }

    /// Selecting a column that does not exist returns an error.
    #[test]
    fn get_item_by_name_none() {
        let scope = Scope::new(test_array().to_array());
        let expr = get_item("c", root());
        assert!(expr.evaluate(&scope).is_err());
    }

    /// Rows where the parent struct is null come back null from the selected column.
    #[test]
    fn get_nullable_field() {
        let names = FieldNames::from(["a"]);
        let columns = vec![PrimitiveArray::from_iter([1i32]).to_array()];
        let parent = StructArray::try_new(names, columns, 1, Validity::AllInvalid)
            .unwrap()
            .to_array();

        let expr = get_item("a", root());
        let column = expr.evaluate(&Scope::new(parent)).unwrap();
        assert_eq!(
            column.scalar_at(0),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }
}