vortex_array/expr/exprs/
get_item.rs1use std::fmt::Formatter;
5use std::ops::Not;
6
7use prost::Message;
8use vortex_dtype::{DType, FieldName, FieldPath, Nullability};
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11
12use crate::compute::mask;
13use crate::expr::exprs::root::root;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
15use crate::stats::Stat;
16use crate::{ArrayRef, ToCanonical};
17
18pub struct GetItem;
19
20impl VTable for GetItem {
21 type Instance = FieldName;
22
23 fn id(&self) -> ExprId {
24 ExprId::from("vortex.get_item")
25 }
26
27 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
28 Ok(Some(
29 pb::GetItemOpts {
30 path: instance.to_string(),
31 }
32 .encode_to_vec(),
33 ))
34 }
35
36 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
37 let opts = pb::GetItemOpts::decode(metadata)?;
38 Ok(Some(FieldName::from(opts.path)))
39 }
40
41 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
42 if expr.children().len() != 1 {
43 vortex_bail!(
44 "GetItem expression requires exactly 1 child, got {}",
45 expr.children().len()
46 );
47 }
48 Ok(())
49 }
50
51 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
52 match child_idx {
53 0 => ChildName::from("input"),
54 _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
55 }
56 }
57
58 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
59 expr.children()[0].fmt_sql(f)?;
60 write!(f, ".{}", expr.data())
61 }
62
63 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
64 write!(f, "\"{}\"", instance.inner().as_ref())
65 }
66
67 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
68 let struct_dtype = expr.children()[0].return_dtype(scope)?;
69 let field_dtype = struct_dtype
70 .as_struct_fields_opt()
71 .and_then(|st| st.field(expr.data()))
72 .ok_or_else(|| {
73 vortex_err!("Couldn't find the {} field in the input scope", expr.data())
74 })?;
75
76 if matches!(
78 (struct_dtype.nullability(), field_dtype.nullability()),
79 (Nullability::Nullable, Nullability::NonNullable)
80 ) {
81 return Ok(field_dtype.with_nullability(Nullability::Nullable));
82 }
83
84 Ok(field_dtype)
85 }
86
87 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
88 let input = expr.children()[0].evaluate(scope)?.to_struct();
89 let field = input.field_by_name(expr.data()).cloned()?;
90
91 match input.dtype().nullability() {
92 Nullability::NonNullable => Ok(field),
93 Nullability::Nullable => mask(&field, &input.validity_mask().not()),
94 }
95 }
96
97 fn stat_max(
98 &self,
99 expr: &ExpressionView<Self>,
100 catalog: &mut dyn StatsCatalog,
101 ) -> Option<Expression> {
102 catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), Stat::Max)
103 }
104
105 fn stat_min(
106 &self,
107 expr: &ExpressionView<Self>,
108 catalog: &mut dyn StatsCatalog,
109 ) -> Option<Expression> {
110 catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), Stat::Min)
111 }
112
113 fn stat_nan_count(
114 &self,
115 expr: &ExpressionView<Self>,
116 catalog: &mut dyn StatsCatalog,
117 ) -> Option<Expression> {
118 catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), Stat::NaNCount)
119 }
120
121 fn stat_field_path(&self, expr: &ExpressionView<Self>) -> Option<FieldPath> {
122 expr.children()[0]
123 .stat_field_path()
124 .map(|fp| fp.push(expr.data().clone()))
125 }
126}
127
128pub fn col(field: impl Into<FieldName>) -> Expression {
137 GetItem.new_expr(field.into(), vec![root()])
138}
139
140pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
149 GetItem.new_expr(field.into(), vec![child])
150}
151
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, FieldNames, Nullability};
    use vortex_scalar::Scalar;

    use super::get_item;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// Builds a three-row struct with an `i32` column `a` and an `i64` column `b`.
    fn test_array() -> StructArray {
        let a = buffer![0i32, 1, 2].into_array();
        let b = buffer![4i64, 5, 6].into_array();
        StructArray::from_fields(&[("a", a), ("b", b)]).unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let input = test_array().to_array();
        let result = get_item("a", root()).evaluate(&input).unwrap();
        assert_eq!(result.dtype(), &DType::from(I32))
    }

    #[test]
    fn get_item_by_name_none() {
        let input = test_array().to_array();
        // Column `c` does not exist, so evaluation must fail.
        let missing = get_item("c", root());
        assert!(missing.evaluate(&input).is_err());
    }

    #[test]
    fn get_nullable_field() {
        // The only row of the parent struct is null, so the extracted value must be null
        // even though the `a` column itself holds a value there.
        let input = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let result = get_item("a", root()).evaluate(&input).unwrap();
        let expected = Scalar::null(DType::Primitive(I32, Nullability::Nullable));
        assert_eq!(result.scalar_at(0), expected);
    }
}