vortex_expr/exprs/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::cast as compute_cast;
5use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
6use vortex_dtype::DType;
7use vortex_error::{VortexResult, vortex_bail, vortex_err};
8use vortex_proto::expr as pb;
9
10use crate::display::{DisplayAs, DisplayFormat};
11use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
12
// NOTE(review): `vtable!` is defined elsewhere in the crate. Judging by the
// `impl VTable for CastVTable` further down, it presumably declares `CastVTable` and
// the related plumbing for the `Cast` expression — confirm against the macro definition.
vtable!(Cast);
14
/// Expression that evaluates `child` and casts the resulting array to `target`.
// `Hash` is derived while `PartialEq` is written by hand, which trips this lint. The manual
// `PartialEq` compares exactly the fields the derived `Hash` feeds (`target` and `child`),
// so `a == b` still implies equal hashes as long as `DType` and `ExprRef` are themselves
// consistent.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct CastExpr {
    /// Data type the child's values are converted to; also the expression's return dtype.
    target: DType,
    /// Expression producing the values to cast.
    child: ExprRef,
}
21
22impl PartialEq for CastExpr {
23    fn eq(&self, other: &Self) -> bool {
24        self.target == other.target && self.child.eq(&other.child)
25    }
26}
27
/// Stateless encoding marker for [`CastExpr`], registered under the expression id `"cast"`.
pub struct CastExprEncoding;
29
30impl VTable for CastVTable {
31    type Expr = CastExpr;
32    type Encoding = CastExprEncoding;
33    type Metadata = ProstMetadata<pb::CastOpts>;
34
35    fn id(_encoding: &Self::Encoding) -> ExprId {
36        ExprId::new_ref("cast")
37    }
38
39    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
40        ExprEncodingRef::new_ref(CastExprEncoding.as_ref())
41    }
42
43    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
44        Some(ProstMetadata(pb::CastOpts {
45            target: Some((&expr.target).into()),
46        }))
47    }
48
49    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
50        vec![&expr.child]
51    }
52
53    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
54        Ok(CastExpr {
55            target: expr.target.clone(),
56            child: children[0].clone(),
57        })
58    }
59
60    fn build(
61        _encoding: &Self::Encoding,
62        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
63        children: Vec<ExprRef>,
64    ) -> VortexResult<Self::Expr> {
65        if children.len() != 1 {
66            vortex_bail!(
67                "Cast expression must have exactly 1 child, got {}",
68                children.len()
69            );
70        }
71        let target: DType = metadata
72            .target
73            .as_ref()
74            .ok_or_else(|| vortex_err!("missing target dtype in CastOpts"))?
75            .try_into()?;
76        Ok(CastExpr {
77            target,
78            child: children[0].clone(),
79        })
80    }
81
82    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83        let array = expr.child.evaluate(scope)?;
84        compute_cast(&array, &expr.target)
85    }
86
87    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
88        Ok(expr.target.clone())
89    }
90}
91
92impl CastExpr {
93    pub fn new(child: ExprRef, target: DType) -> Self {
94        Self { target, child }
95    }
96
97    pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
98        Self::new(child, target).into_expr()
99    }
100}
101
102impl DisplayAs for CastExpr {
103    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
104        match df {
105            DisplayFormat::Compact => {
106                write!(f, "cast({}, {})", self.child, self.target)
107            }
108            DisplayFormat::Tree => {
109                write!(f, "Cast(target: {})", self.target)
110            }
111        }
112    }
113}
114
// Uses only the trait's default methods; no cast-specific analysis is provided here.
impl AnalysisExpr for CastExpr {}
116
117/// Creates an expression that casts values to a target data type.
118///
119/// Converts the input expression's values to the specified target type.
120///
121/// ```rust
122/// # use vortex_dtype::{DType, Nullability, PType};
123/// # use vortex_expr::{cast, root};
124/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
125/// ```
126pub fn cast(child: ExprRef, target: DType) -> ExprRef {
127    CastExpr::new(child, target).into_expr()
128}
129
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, cast, get_item, root, test_harness};

    /// The return dtype is the cast target, regardless of the input scope dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    /// Replacing the child succeeds and keeps the target dtype.
    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // Previously the result was discarded, so a failure or a dropped target went unnoticed.
        let replaced = expr.with_children(vec![root()]).unwrap();
        assert_eq!(replaced, cast(root(), DType::Bool(Nullability::Nullable)));
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_display() {
        let expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(expr.to_string(), "cast($.value, i64)");

        let expr2 = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(expr2.to_string(), "cast($, bool?)");
    }
}
188}