vortex_array/expr/exprs/get_item/
mod.rs1pub mod transform;
5
6use std::fmt::Formatter;
7use std::ops::Not;
8
9use prost::Message;
10use vortex_dtype::DType;
11use vortex_dtype::FieldName;
12use vortex_dtype::FieldPath;
13use vortex_dtype::Nullability;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18use vortex_proto::expr as pb;
19use vortex_vector::Vector;
20use vortex_vector::VectorOps;
21
22use crate::ArrayRef;
23use crate::ToCanonical;
24use crate::compute::mask;
25use crate::expr::ChildName;
26use crate::expr::ExecutionArgs;
27use crate::expr::ExprId;
28use crate::expr::Expression;
29use crate::expr::ExpressionView;
30use crate::expr::StatsCatalog;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::expr::exprs::root::root;
34use crate::expr::stats::Stat;
35
/// Expression vtable that extracts a single named field from its one child expression.
///
/// The field name is carried as the expression's instance data (see the `VTable`
/// implementation in this file, where `Instance = FieldName`).
pub struct GetItem;
37
38impl VTable for GetItem {
39 type Instance = FieldName;
40
41 fn id(&self) -> ExprId {
42 ExprId::from("vortex.get_item")
43 }
44
45 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
46 Ok(Some(
47 pb::GetItemOpts {
48 path: instance.to_string(),
49 }
50 .encode_to_vec(),
51 ))
52 }
53
54 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
55 let opts = pb::GetItemOpts::decode(metadata)?;
56 Ok(Some(FieldName::from(opts.path)))
57 }
58
59 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
60 if expr.children().len() != 1 {
61 vortex_bail!(
62 "GetItem expression requires exactly 1 child, got {}",
63 expr.children().len()
64 );
65 }
66 Ok(())
67 }
68
69 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
70 match child_idx {
71 0 => ChildName::from("input"),
72 _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
73 }
74 }
75
76 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
77 expr.children()[0].fmt_sql(f)?;
78 write!(f, ".{}", expr.data())
79 }
80
81 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
82 write!(f, "\"{}\"", instance)
83 }
84
85 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
86 let struct_dtype = expr.children()[0].return_dtype(scope)?;
87 let field_dtype = struct_dtype
88 .as_struct_fields_opt()
89 .and_then(|st| st.field(expr.data()))
90 .ok_or_else(|| {
91 vortex_err!("Couldn't find the {} field in the input scope", expr.data())
92 })?;
93
94 if matches!(
96 (struct_dtype.nullability(), field_dtype.nullability()),
97 (Nullability::Nullable, Nullability::NonNullable)
98 ) {
99 return Ok(field_dtype.with_nullability(Nullability::Nullable));
100 }
101
102 Ok(field_dtype)
103 }
104
105 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
106 let input = expr.children()[0].evaluate(scope)?.to_struct();
107 let field = input.field_by_name(expr.data()).cloned()?;
108
109 match input.dtype().nullability() {
110 Nullability::NonNullable => Ok(field),
111 Nullability::Nullable => mask(&field, &input.validity_mask().not()),
112 }
113 }
114
115 fn stat_expression(
116 &self,
117 expr: &ExpressionView<Self>,
118 stat: Stat,
119 catalog: &dyn StatsCatalog,
120 ) -> Option<Expression> {
121 catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), stat)
130 }
131
132 fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<Vector> {
133 let struct_dtype = args.dtypes[0]
134 .as_struct_fields_opt()
135 .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?;
136 let field_idx = struct_dtype
137 .find(field_name)
138 .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?;
139
140 let struct_vector = args
141 .vectors
142 .pop()
143 .vortex_expect("missing input")
144 .into_struct();
145
146 let mut field = struct_vector.fields()[field_idx].clone();
148 field.mask_validity(struct_vector.validity());
149
150 Ok(field)
151 }
152
153 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
155 true
156 }
157
158 fn is_fallible(&self, _instance: &Self::Instance) -> bool {
159 false
161 }
162}
163
164pub fn col(field: impl Into<FieldName>) -> Expression {
173 GetItem.new_expr(field.into(), vec![root()])
174}
175
176pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
185 GetItem.new_expr(field.into(), vec![child])
186}
187
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use super::get_item;
    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;

    /// Builds a three-row struct with an `i32` column `a` and an `i64` column `b`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Extracting an existing field yields an array with that field's dtype.
    #[test]
    fn get_item_by_name() {
        let input = test_array().to_array();
        let expr = get_item("a", root());

        let extracted = expr.evaluate(&input).unwrap();

        let expected_dtype = DType::from(I32);
        assert_eq!(extracted.dtype(), &expected_dtype)
    }

    /// Extracting a field that does not exist is an error.
    #[test]
    fn get_item_by_name_none() {
        let input = test_array().to_array();
        let expr = get_item("c", root());

        let outcome = expr.evaluate(&input);

        assert!(outcome.is_err());
    }

    /// A field read out of an all-invalid struct comes back as a nullable null.
    #[test]
    fn get_nullable_field() {
        let names = FieldNames::from(["a"]);
        let columns = vec![buffer![1i32].into_array()];
        let input = StructArray::try_new(names, columns, 1, Validity::AllInvalid)
            .unwrap()
            .to_array();

        let expr = get_item("a", root());
        let extracted = expr.evaluate(&input).unwrap();

        let expected = Scalar::null(DType::Primitive(I32, Nullability::Nullable));
        assert_eq!(extracted.scalar_at(0), expected);
    }
}