vortex_array/expr/exprs/
get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Not;
6
7use prost::Message;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldPath;
11use vortex_dtype::Nullability;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_err;
15use vortex_proto::expr as pb;
16use vortex_vector::Datum;
17use vortex_vector::ScalarOps;
18use vortex_vector::VectorOps;
19
20use crate::ArrayRef;
21use crate::ToCanonical;
22use crate::builtins::ExprBuiltins;
23use crate::compute::mask;
24use crate::expr::Arity;
25use crate::expr::ChildName;
26use crate::expr::EmptyOptions;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Expression;
30use crate::expr::Literal;
31use crate::expr::Mask;
32use crate::expr::Pack;
33use crate::expr::ReduceCtx;
34use crate::expr::ReduceNode;
35use crate::expr::ReduceNodeRef;
36use crate::expr::StatsCatalog;
37use crate::expr::VTable;
38use crate::expr::VTableExt;
39use crate::expr::exprs::root::root;
40use crate::expr::lit;
41use crate::expr::stats::Stat;
42
/// Expression vtable for the `vortex.get_item` expression, which extracts a named field
/// from a struct-typed child expression.
pub struct GetItem;
44
45impl VTable for GetItem {
46    type Options = FieldName;
47
48    fn id(&self) -> ExprId {
49        ExprId::from("vortex.get_item")
50    }
51
52    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        Ok(Some(
54            pb::GetItemOpts {
55                path: instance.to_string(),
56            }
57            .encode_to_vec(),
58        ))
59    }
60
61    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
62        let opts = pb::GetItemOpts::decode(metadata)?;
63        Ok(FieldName::from(opts.path))
64    }
65
66    fn arity(&self, _field_name: &FieldName) -> Arity {
67        Arity::Exact(1)
68    }
69
70    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
71        match child_idx {
72            0 => ChildName::from("input"),
73            _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
74        }
75    }
76
77    fn fmt_sql(
78        &self,
79        field_name: &FieldName,
80        expr: &Expression,
81        f: &mut Formatter<'_>,
82    ) -> std::fmt::Result {
83        expr.children()[0].fmt_sql(f)?;
84        write!(f, ".{}", field_name)
85    }
86
87    fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
88        let struct_dtype = &arg_dtypes[0];
89        let field_dtype = struct_dtype
90            .as_struct_fields_opt()
91            .and_then(|st| st.field(field_name))
92            .ok_or_else(|| {
93                vortex_err!("Couldn't find the {} field in the input scope", field_name)
94            })?;
95
96        // Match here to avoid cloning the dtype if nullability doesn't need to change
97        if matches!(
98            (struct_dtype.nullability(), field_dtype.nullability()),
99            (Nullability::Nullable, Nullability::NonNullable)
100        ) {
101            return Ok(field_dtype.with_nullability(Nullability::Nullable));
102        }
103
104        Ok(field_dtype)
105    }
106
107    fn evaluate(
108        &self,
109        field_name: &FieldName,
110        expr: &Expression,
111        scope: &ArrayRef,
112    ) -> VortexResult<ArrayRef> {
113        let input = expr.children()[0].evaluate(scope)?.to_struct();
114        let field = input.field_by_name(field_name).cloned()?;
115
116        match input.dtype().nullability() {
117            Nullability::NonNullable => Ok(field),
118            Nullability::Nullable => mask(&field, &input.validity_mask().not()),
119        }
120    }
121
122    fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<Datum> {
123        let struct_dtype = args.dtypes[0]
124            .as_struct_fields_opt()
125            .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?;
126        let field_idx = struct_dtype
127            .find(field_name)
128            .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?;
129
130        match args.datums.pop().vortex_expect("missing input") {
131            Datum::Scalar(s) => {
132                let mut field = s.as_struct().field(field_idx);
133                field.mask_validity(s.is_valid());
134                Ok(Datum::Scalar(field))
135            }
136            Datum::Vector(v) => {
137                let mut field = v.as_struct().fields()[field_idx].clone();
138                field.mask_validity(v.validity());
139                Ok(Datum::Vector(field))
140            }
141        }
142    }
143
144    fn reduce(
145        &self,
146        field_name: &FieldName,
147        node: &dyn ReduceNode,
148        ctx: &dyn ReduceCtx,
149    ) -> VortexResult<Option<ReduceNodeRef>> {
150        let child = node.child(0);
151        if let Some(child_fn) = child.scalar_fn()
152            && let Some(pack) = child_fn.as_opt::<Pack>()
153            && let Some(idx) = pack.names.find(field_name)
154        {
155            let mut field = child.child(idx);
156
157            // Possibly mask the field if the pack is nullable
158            if pack.nullability.is_nullable() {
159                field = ctx.new_node(
160                    Mask.bind(EmptyOptions),
161                    &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
162                )?;
163            }
164
165            return Ok(Some(field));
166        }
167
168        Ok(None)
169    }
170
171    fn simplify_untyped(
172        &self,
173        field_name: &FieldName,
174        expr: &Expression,
175    ) -> VortexResult<Option<Expression>> {
176        let child = expr.child(0);
177
178        // If the child is a Pack expression, we can directly return the corresponding child.
179        if let Some(pack) = child.as_opt::<Pack>() {
180            let idx = pack
181                .names
182                .iter()
183                .position(|name| name == field_name)
184                .ok_or_else(|| {
185                    vortex_err!(
186                        "Cannot find field {} in pack fields {:?}",
187                        field_name,
188                        pack.names
189                    )
190                })?;
191
192            let mut field = child.child(idx).clone();
193
194            // It's useful to simplify this node without type info, but we need to make sure
195            // the nullability is correct. We cannot cast since we don't have the dtype info here,
196            // so instead we insert a Mask expression that we know converts a child's dtype to
197            // nullable.
198            if pack.nullability.is_nullable() {
199                // Mask with an all-true array to ensure the field DType is nullable.
200                field = field.mask(lit(true))?;
201            }
202
203            return Ok(Some(field));
204        }
205
206        Ok(None)
207    }
208
209    fn stat_expression(
210        &self,
211        field_name: &FieldName,
212        _expr: &Expression,
213        stat: Stat,
214        catalog: &dyn StatsCatalog,
215    ) -> Option<Expression> {
216        // TODO(ngates): I think we can do better here and support stats over nested fields.
217        //  It would be nice if delegating to our child would return a struct of statistics
218        //  matching the nested DType such that we can write:
219        //    `get_item(expr.child(0).stat_expression(...), expr.data().field_name())`
220
221        // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same
222        //  name as a field in the root struct. This should be resolved with upcoming change to
223        //  falsify expressions, but for now I'm preserving the existing buggy behavior.
224        catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
225    }
226
227    // This will apply struct nullability field. We could add a dtype??
228    fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
229        true
230    }
231
232    fn is_fallible(&self, _field_name: &FieldName) -> bool {
233        // If this type-checks its infallible.
234        false
235    }
236}
237
238/// Creates an expression that accesses a field from the root array.
239///
240/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
241///
242/// ```rust
243/// # use vortex_array::expr::col;
244/// let expr = col("name");
245/// ```
246pub fn col(field: impl Into<FieldName>) -> Expression {
247    GetItem.new_expr(field.into(), vec![root()])
248}
249
250/// Creates an expression that extracts a named field from a struct expression.
251///
252/// Accesses the specified field from the result of the child expression.
253///
254/// ```rust
255/// # use vortex_array::expr::{get_item, root};
256/// let expr = get_item("user_id", root());
257/// ```
258pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
259    GetItem.new_expr(field.into(), vec![child])
260}
261
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;

    // Fixture: a struct array with an i32 field `a` and an i64 field `b`, three rows each.
    fn test_array() -> StructArray {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap()
    }

    // Extracting an existing field yields an array with that field's dtype.
    #[test]
    fn get_item_by_name() {
        let array = test_array().to_array();
        let expr = get_item("a", root());
        let item = expr.evaluate(&array).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32));
    }

    // Asking for a field the struct does not have is an evaluation error.
    #[test]
    fn get_item_by_name_none() {
        let array = test_array().to_array();
        let expr = get_item("c", root());
        assert!(expr.evaluate(&array).is_err());
    }

    // A field pulled out of an all-invalid struct comes back as a nullable null scalar.
    #[test]
    fn get_nullable_field() {
        let all_invalid = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let expr = get_item("a", root());
        let item = expr.evaluate(&all_invalid).unwrap();
        let expected = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        assert_eq!(item.scalar_at(0), expected);
    }

    // get_item over a pack collapses to the packed child.
    #[test]
    fn test_pack_get_item_rule() {
        // Create: pack(a: lit(1), b: lit(2)).get_item("b")
        let packed = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let expr = get_item("b", packed);

        let scope_dtype = DType::Struct(StructFields::empty(), NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        assert_eq!(optimized, lit(2));
    }

    // The outer get_item picks a sibling of a nested get_item-over-pack.
    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let inner_a = get_item("a", inner);

        let outer = pack([("x", inner_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let expr = get_item("z", outer);

        let scope_dtype = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        assert_eq!(optimized, lit(4));
    }

    // Four levels of pack/get_item nesting collapse to the innermost literal.
    #[test]
    fn test_deeply_nested_pack_get_item() {
        let level1 = get_item("a", pack([("a", lit(42))], NonNullable));
        let level2 = get_item("b", pack([("b", level1)], NonNullable));
        let level3 = get_item("c", pack([("c", level2)], NonNullable));
        let expr = get_item("final", pack([("final", level3)], NonNullable));

        let scope_dtype = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        assert_eq!(optimized, lit(42));
    }

    // A get_item buried inside another expression is simplified as well.
    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let sum = checked_add(get_item("x", inner), lit(10));

        let outer = pack([("result", sum)], NonNullable);
        let expr = get_item("result", outer);

        let scope_dtype = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        let expected = checked_add(lit(1), lit(10));
        assert_eq!(&optimized, &expected);
    }
}
387}