vortex_array/expr/exprs/get_item/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4pub mod transform;
5
6use std::fmt::Formatter;
7use std::ops::Not;
8
9use prost::Message;
10use vortex_dtype::DType;
11use vortex_dtype::FieldName;
12use vortex_dtype::FieldPath;
13use vortex_dtype::Nullability;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18use vortex_proto::expr as pb;
19use vortex_vector::Vector;
20use vortex_vector::VectorOps;
21
22use crate::ArrayRef;
23use crate::ToCanonical;
24use crate::compute::mask;
25use crate::expr::ChildName;
26use crate::expr::ExecutionArgs;
27use crate::expr::ExprId;
28use crate::expr::Expression;
29use crate::expr::ExpressionView;
30use crate::expr::StatsCatalog;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::expr::exprs::root::root;
34use crate::expr::stats::Stat;
35
36pub struct GetItem;
37
38impl VTable for GetItem {
39    type Instance = FieldName;
40
41    fn id(&self) -> ExprId {
42        ExprId::from("vortex.get_item")
43    }
44
45    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
46        Ok(Some(
47            pb::GetItemOpts {
48                path: instance.to_string(),
49            }
50            .encode_to_vec(),
51        ))
52    }
53
54    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
55        let opts = pb::GetItemOpts::decode(metadata)?;
56        Ok(Some(FieldName::from(opts.path)))
57    }
58
59    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
60        if expr.children().len() != 1 {
61            vortex_bail!(
62                "GetItem expression requires exactly 1 child, got {}",
63                expr.children().len()
64            );
65        }
66        Ok(())
67    }
68
69    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
70        match child_idx {
71            0 => ChildName::from("input"),
72            _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
73        }
74    }
75
76    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
77        expr.children()[0].fmt_sql(f)?;
78        write!(f, ".{}", expr.data())
79    }
80
81    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
82        write!(f, "\"{}\"", instance)
83    }
84
85    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
86        let struct_dtype = expr.children()[0].return_dtype(scope)?;
87        let field_dtype = struct_dtype
88            .as_struct_fields_opt()
89            .and_then(|st| st.field(expr.data()))
90            .ok_or_else(|| {
91                vortex_err!("Couldn't find the {} field in the input scope", expr.data())
92            })?;
93
94        // Match here to avoid cloning the dtype if nullability doesn't need to change
95        if matches!(
96            (struct_dtype.nullability(), field_dtype.nullability()),
97            (Nullability::Nullable, Nullability::NonNullable)
98        ) {
99            return Ok(field_dtype.with_nullability(Nullability::Nullable));
100        }
101
102        Ok(field_dtype)
103    }
104
105    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
106        let input = expr.children()[0].evaluate(scope)?.to_struct();
107        let field = input.field_by_name(expr.data()).cloned()?;
108
109        match input.dtype().nullability() {
110            Nullability::NonNullable => Ok(field),
111            Nullability::Nullable => mask(&field, &input.validity_mask().not()),
112        }
113    }
114
115    fn stat_expression(
116        &self,
117        expr: &ExpressionView<Self>,
118        stat: Stat,
119        catalog: &dyn StatsCatalog,
120    ) -> Option<Expression> {
121        // TODO(ngates): I think we can do better here and support stats over nested fields.
122        //  It would be nice if delegating to our child would return a struct of statistics
123        //  matching the nested DType such that we can write:
124        //    `get_item(expr.child(0).stat_expression(...), expr.data().field_name())`
125
126        // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same
127        //  name as a field in the root struct. This should be resolved with upcoming change to
128        //  falsify expressions, but for now I'm preserving the existing buggy behavior.
129        catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), stat)
130    }
131
132    fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<Vector> {
133        let struct_dtype = args.dtypes[0]
134            .as_struct_fields_opt()
135            .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?;
136        let field_idx = struct_dtype
137            .find(field_name)
138            .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?;
139
140        let struct_vector = args
141            .vectors
142            .pop()
143            .vortex_expect("missing input")
144            .into_struct();
145
146        // We must intersect the validity with that of the parent struct
147        let mut field = struct_vector.fields()[field_idx].clone();
148        field.mask_validity(struct_vector.validity());
149
150        Ok(field)
151    }
152
153    // This will apply struct nullability field. We could add a dtype??
154    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
155        true
156    }
157
158    fn is_fallible(&self, _instance: &Self::Instance) -> bool {
159        // If this type-checks its infallible.
160        false
161    }
162}
163
164/// Creates an expression that accesses a field from the root array.
165///
166/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
167///
168/// ```rust
169/// # use vortex_array::expr::col;
170/// let expr = col("name");
171/// ```
172pub fn col(field: impl Into<FieldName>) -> Expression {
173    GetItem.new_expr(field.into(), vec![root()])
174}
175
176/// Creates an expression that extracts a named field from a struct expression.
177///
178/// Accesses the specified field from the result of the child expression.
179///
180/// ```rust
181/// # use vortex_array::expr::{get_item, root};
182/// let expr = get_item("user_id", root());
183/// ```
184pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
185    GetItem.new_expr(field.into(), vec![child])
186}
187
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use super::GetItem;
    use super::get_item;
    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::VTable;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;

    /// Non-nullable struct with an `i32` field `a` and an `i64` field `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// One-row struct with a non-nullable `i32` field `a`, where the struct row itself is null.
    fn all_null_struct() -> StructArray {
        StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let expr = get_item("a", root());
        let item = expr.evaluate(&st.to_array()).unwrap();
        assert_eq!(item.dtype(), &DType::from(I32));
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let expr = get_item("c", root());
        assert!(expr.evaluate(&st.to_array()).is_err());
    }

    #[test]
    fn get_nullable_field() {
        let st = all_null_struct().to_array();
        let expr = get_item("a", root());
        let item = expr.evaluate(&st).unwrap();
        assert_eq!(
            item.scalar_at(0),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }

    #[test]
    fn return_dtype_keeps_field_nullability_for_non_nullable_struct() {
        let st = test_array();
        let dtype = get_item("a", root()).return_dtype(st.dtype()).unwrap();
        assert_eq!(dtype, DType::from(I32));
    }

    #[test]
    fn return_dtype_widens_field_of_nullable_struct() {
        // The field is non-nullable, but a null parent row must yield a null projected value.
        let st = all_null_struct();
        let dtype = get_item("a", root()).return_dtype(st.dtype()).unwrap();
        assert_eq!(dtype, DType::Primitive(I32, Nullability::Nullable));
    }

    #[test]
    fn serialize_deserialize_roundtrip() {
        let name = FieldName::from("a");
        let bytes = GetItem.serialize(&name).unwrap().unwrap();
        assert_eq!(GetItem.deserialize(&bytes).unwrap(), Some(name));
    }
}