vortex_expr/exprs/
get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6use std::ops::Not;
7
8use vortex_array::compute::mask;
9use vortex_array::stats::Stat;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, ToCanonical};
11use vortex_dtype::{DType, FieldName, FieldPath, Nullability};
12use vortex_error::{VortexResult, vortex_bail, vortex_err};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, root,
18    vtable,
19};
20
21vtable!(GetItem);
22
23#[allow(clippy::derived_hash_with_manual_eq)]
24#[derive(Debug, Clone, Hash, Eq)]
25pub struct GetItemExpr {
26    field: FieldName,
27    child: ExprRef,
28}
29
30impl PartialEq for GetItemExpr {
31    fn eq(&self, other: &Self) -> bool {
32        self.field == other.field && self.child.eq(&other.child)
33    }
34}
35
36pub struct GetItemExprEncoding;
37
38impl VTable for GetItemVTable {
39    type Expr = GetItemExpr;
40    type Encoding = GetItemExprEncoding;
41    type Metadata = ProstMetadata<pb::GetItemOpts>;
42
43    fn id(_encoding: &Self::Encoding) -> ExprId {
44        ExprId::new_ref("get_item")
45    }
46
47    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
48        ExprEncodingRef::new_ref(GetItemExprEncoding.as_ref())
49    }
50
51    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
52        Some(ProstMetadata(pb::GetItemOpts {
53            path: expr.field.to_string(),
54        }))
55    }
56
57    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
58        vec![&expr.child]
59    }
60
61    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
62        Ok(GetItemExpr {
63            field: expr.field.clone(),
64            child: children[0].clone(),
65        })
66    }
67
68    fn build(
69        _encoding: &Self::Encoding,
70        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
71        children: Vec<ExprRef>,
72    ) -> VortexResult<Self::Expr> {
73        if children.len() != 1 {
74            vortex_bail!(
75                "GetItem expression must have exactly 1 child, got {}",
76                children.len()
77            );
78        }
79
80        let field = FieldName::from(metadata.path.clone());
81        Ok(GetItemExpr {
82            field,
83            child: children[0].clone(),
84        })
85    }
86
87    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
88        let input = expr.child.unchecked_evaluate(scope)?.to_struct();
89        let field = input.field_by_name(expr.field()).cloned()?;
90
91        match input.dtype().nullability() {
92            Nullability::NonNullable => Ok(field),
93            Nullability::Nullable => mask(&field, &input.validity_mask().not()),
94        }
95    }
96
97    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
98        let input = expr.child.return_dtype(scope)?;
99        input
100            .as_struct_fields_opt()
101            .and_then(|st| st.field(expr.field()))
102            .map(|f| f.union_nullability(input.nullability()))
103            .ok_or_else(|| {
104                vortex_err!(
105                    "Couldn't find the {} field in the input scope",
106                    expr.field()
107                )
108            })
109    }
110}
111
112impl GetItemExpr {
113    pub fn new(field: impl Into<FieldName>, child: ExprRef) -> Self {
114        Self {
115            field: field.into(),
116            child,
117        }
118    }
119
120    pub fn new_expr(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
121        Self::new(field, child).into_expr()
122    }
123
124    pub fn field(&self) -> &FieldName {
125        &self.field
126    }
127
128    pub fn child(&self) -> &ExprRef {
129        &self.child
130    }
131
132    pub fn is(expr: &ExprRef) -> bool {
133        expr.is::<GetItemVTable>()
134    }
135}
136
137/// Creates an expression that accesses a field from the root array.
138///
139/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
140///
141/// ```rust
142/// # use vortex_expr::col;
143/// let expr = col("name");
144/// ```
145pub fn col(field: impl Into<FieldName>) -> ExprRef {
146    GetItemExpr::new(field, root()).into_expr()
147}
148
149/// Creates an expression that extracts a named field from a struct expression.
150///
151/// Accesses the specified field from the result of the child expression.
152///
153/// ```rust
154/// # use vortex_expr::{get_item, root};
155/// let expr = get_item("user_id", root());
156/// ```
157pub fn get_item(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
158    GetItemExpr::new(field, child).into_expr()
159}
160
161impl DisplayAs for GetItemExpr {
162    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
163        match df {
164            DisplayFormat::Compact => {
165                write!(f, "{}.{}", self.child, &self.field)
166            }
167            DisplayFormat::Tree => {
168                write!(f, "GetItem({})", self.field)
169            }
170        }
171    }
172}
173impl AnalysisExpr for GetItemExpr {
174    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
175        catalog.stats_ref(&self.field_path()?, Stat::Max)
176    }
177
178    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
179        catalog.stats_ref(&self.field_path()?, Stat::Min)
180    }
181
182    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
183        catalog.stats_ref(&self.field_path()?, Stat::NaNCount)
184    }
185
186    fn field_path(&self) -> Option<FieldPath> {
187        self.child()
188            .field_path()
189            .map(|fp| fp.push(self.field.clone()))
190    }
191}
192
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, FieldNames, Nullability};
    use vortex_scalar::Scalar;

    use crate::get_item::get_item;
    use crate::{Scope, root};

    /// Three-row struct `{a: i32, b: i64}`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// One-row struct `{a: i32}` whose struct-level validity marks every row invalid.
    fn all_invalid_array() -> StructArray {
        StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::from_iter([1i32]).to_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let get_item = get_item("a", root());
        let item = get_item.evaluate(&Scope::new(st.to_array())).unwrap();
        assert_eq!(item.dtype(), &DType::from(I32))
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let get_item = get_item("c", root());
        assert!(get_item.evaluate(&Scope::new(st.to_array())).is_err());
    }

    #[test]
    fn get_nullable_field() {
        let st = all_invalid_array().to_array();

        let get_item = get_item("a", root());
        let item = get_item.evaluate(&Scope::new(st)).unwrap();
        assert_eq!(
            item.scalar_at(0),
            Scalar::null(DType::Primitive(I32, Nullability::Nullable))
        );
    }

    #[test]
    fn return_dtype_by_name() {
        let st = test_array();
        let dtype = get_item("a", root()).return_dtype(st.dtype()).unwrap();
        assert_eq!(dtype, DType::from(I32));
    }

    #[test]
    fn return_dtype_missing_field() {
        let st = test_array();
        assert!(get_item("c", root()).return_dtype(st.dtype()).is_err());
    }

    #[test]
    fn return_dtype_inherits_struct_nullability() {
        // The field itself is non-nullable, but the enclosing struct is nullable, so the
        // reported dtype must be nullable (matching what `evaluate` produces).
        let st = all_invalid_array();
        let dtype = get_item("a", root()).return_dtype(st.dtype()).unwrap();
        assert_eq!(dtype, DType::Primitive(I32, Nullability::Nullable));
    }
}