vortex_expr/exprs/get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6
7use vortex_array::stats::Stat;
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, ToCanonical};
9use vortex_dtype::{DType, FieldName, FieldPath};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12
13use crate::{
14    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, root,
15    vtable,
16};
17
// Register the `GetItem` expression with the crate's `vtable!` macro.
// NOTE(review): `GetItemVTable` (used below) and the `as_ref()` called on
// `GetItemExprEncoding` are not defined in this file; presumably this macro
// generates them — confirm against the macro definition in the crate root.
vtable!(GetItem);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash, Eq)]
22pub struct GetItemExpr {
23    field: FieldName,
24    child: ExprRef,
25}
26
27impl PartialEq for GetItemExpr {
28    fn eq(&self, other: &Self) -> bool {
29        self.field == other.field && self.child.eq(&other.child)
30    }
31}
32
/// Zero-sized encoding marker for [`GetItemExpr`]; used as `Encoding` in the
/// `VTable` impl for `GetItemVTable`.
pub struct GetItemExprEncoding;
34
35impl VTable for GetItemVTable {
36    type Expr = GetItemExpr;
37    type Encoding = GetItemExprEncoding;
38    type Metadata = ProstMetadata<pb::GetItemOpts>;
39
40    fn id(_encoding: &Self::Encoding) -> ExprId {
41        ExprId::new_ref("get_item")
42    }
43
44    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
45        ExprEncodingRef::new_ref(GetItemExprEncoding.as_ref())
46    }
47
48    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
49        Some(ProstMetadata(pb::GetItemOpts {
50            path: expr.field.to_string(),
51        }))
52    }
53
54    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
55        vec![&expr.child]
56    }
57
58    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
59        Ok(GetItemExpr {
60            field: expr.field.clone(),
61            child: children[0].clone(),
62        })
63    }
64
65    fn build(
66        _encoding: &Self::Encoding,
67        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
68        children: Vec<ExprRef>,
69    ) -> VortexResult<Self::Expr> {
70        if children.len() != 1 {
71            vortex_bail!(
72                "GetItem expression must have exactly 1 child, got {}",
73                children.len()
74            );
75        }
76
77        let field = FieldName::from(metadata.path.clone());
78        Ok(GetItemExpr {
79            field,
80            child: children[0].clone(),
81        })
82    }
83
84    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
85        expr.child
86            .unchecked_evaluate(scope)?
87            .to_struct()?
88            .field_by_name(expr.field())
89            .cloned()
90    }
91
92    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
93        let input = expr.child.return_dtype(scope)?;
94        input
95            .as_struct()
96            .and_then(|st| st.field(expr.field()))
97            .ok_or_else(|| {
98                vortex_err!(
99                    "Couldn't find the {} field in the input scope",
100                    expr.field()
101                )
102            })
103    }
104}
105
106impl GetItemExpr {
107    pub fn new(field: impl Into<FieldName>, child: ExprRef) -> Self {
108        Self {
109            field: field.into(),
110            child,
111        }
112    }
113
114    pub fn new_expr(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
115        Self::new(field, child).into_expr()
116    }
117
118    pub fn field(&self) -> &FieldName {
119        &self.field
120    }
121
122    pub fn child(&self) -> &ExprRef {
123        &self.child
124    }
125
126    pub fn is(expr: &ExprRef) -> bool {
127        expr.is::<GetItemVTable>()
128    }
129}
130
131/// Creates an expression that accesses a field from the root array.
132///
133/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
134///
135/// ```rust
136/// # use vortex_expr::col;
137/// let expr = col("name");
138/// ```
139pub fn col(field: impl Into<FieldName>) -> ExprRef {
140    GetItemExpr::new(field, root()).into_expr()
141}
142
143/// Creates an expression that extracts a named field from a struct expression.
144///
145/// Accesses the specified field from the result of the child expression.
146///
147/// ```rust
148/// # use vortex_expr::{get_item, root};
149/// let expr = get_item("user_id", root());
150/// ```
151pub fn get_item(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
152    GetItemExpr::new(field, child).into_expr()
153}
154
155impl Display for GetItemExpr {
156    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
157        write!(f, "{}.{}", self.child, &self.field)
158    }
159}
160
161impl AnalysisExpr for GetItemExpr {
162    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
163        catalog.stats_ref(&self.field_path()?, Stat::Max)
164    }
165
166    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
167        catalog.stats_ref(&self.field_path()?, Stat::Min)
168    }
169
170    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
171        catalog.stats_ref(&self.field_path()?, Stat::NaNCount)
172    }
173
174    fn field_path(&self) -> Option<FieldPath> {
175        self.child()
176            .field_path()
177            .map(|fp| fp.push(self.field.clone()))
178    }
179}
180
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::PType::I32;

    use crate::get_item::get_item;
    use crate::{Scope, root};

    /// Two-column struct fixture: `a` holds i32 values, `b` holds i64 values.
    fn test_array() -> StructArray {
        let col_a = buffer![0i32, 1, 2].into_array();
        let col_b = buffer![4i64, 5, 6].into_array();
        StructArray::from_fields(&[("a", col_a), ("b", col_b)]).unwrap()
    }

    #[test]
    pub fn get_item_by_name() {
        let input = test_array();
        let expr = get_item("a", root());
        let scope = Scope::new(input.to_array());
        let column = expr.evaluate(&scope).unwrap();
        assert_eq!(column.dtype(), &DType::from(I32));
    }

    #[test]
    pub fn get_item_by_name_none() {
        let input = test_array();
        let expr = get_item("c", root());
        let scope = Scope::new(input.to_array());
        // "c" is not a field of the fixture, so evaluation must fail.
        assert!(expr.evaluate(&scope).is_err());
    }
}