vortex_expr/exprs/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use vortex_array::compute::cast as compute_cast;
7use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11
12use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
13
// Invokes the crate's `vtable!` macro for the `Cast` expression.
// NOTE(review): `CastVTable` is implemented below but not declared in this file, so it is
// presumably generated by this macro — confirm against the macro definition.
vtable!(Cast);
15
16#[allow(clippy::derived_hash_with_manual_eq)]
17#[derive(Debug, Clone, Hash, Eq)]
18pub struct CastExpr {
19    target: DType,
20    child: ExprRef,
21}
22
23impl PartialEq for CastExpr {
24    fn eq(&self, other: &Self) -> bool {
25        self.target == other.target && self.child.eq(&other.child)
26    }
27}
28
/// Field-less marker type used as the `Encoding` associated type of the cast vtable.
pub struct CastExprEncoding;
30
31impl VTable for CastVTable {
32    type Expr = CastExpr;
33    type Encoding = CastExprEncoding;
34    type Metadata = ProstMetadata<pb::CastOpts>;
35
36    fn id(_encoding: &Self::Encoding) -> ExprId {
37        ExprId::new_ref("cast")
38    }
39
40    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
41        ExprEncodingRef::new_ref(CastExprEncoding.as_ref())
42    }
43
44    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
45        Some(ProstMetadata(pb::CastOpts {
46            target: Some((&expr.target).into()),
47        }))
48    }
49
50    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
51        vec![&expr.child]
52    }
53
54    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
55        Ok(CastExpr {
56            target: expr.target.clone(),
57            child: children[0].clone(),
58        })
59    }
60
61    fn build(
62        _encoding: &Self::Encoding,
63        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
64        children: Vec<ExprRef>,
65    ) -> VortexResult<Self::Expr> {
66        if children.len() != 1 {
67            vortex_bail!(
68                "Cast expression must have exactly 1 child, got {}",
69                children.len()
70            );
71        }
72        let target: DType = metadata
73            .target
74            .as_ref()
75            .ok_or_else(|| vortex_err!("missing target dtype in CastOpts"))?
76            .try_into()?;
77        Ok(CastExpr {
78            target,
79            child: children[0].clone(),
80        })
81    }
82
83    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
84        let array = expr.child.evaluate(scope)?;
85        compute_cast(&array, &expr.target)
86    }
87
88    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
89        Ok(expr.target.clone())
90    }
91}
92
93impl CastExpr {
94    pub fn new(child: ExprRef, target: DType) -> Self {
95        Self { target, child }
96    }
97
98    pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
99        Self::new(child, target).into_expr()
100    }
101}
102
103impl Display for CastExpr {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(f, "cast({}, {})", self.child, self.target)
106    }
107}
108
// No `AnalysisExpr` items are overridden for casts; the empty impl relies on the trait's defaults.
impl AnalysisExpr for CastExpr {}
110
111/// Creates an expression that casts values to a target data type.
112///
113/// Converts the input expression's values to the specified target type.
114///
115/// ```rust
116/// # use vortex_dtype::{DType, Nullability, PType};
117/// # use vortex_expr::{cast, root};
118/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
119/// ```
120pub fn cast(child: ExprRef, target: DType) -> ExprRef {
121    CastExpr::new(child, target).into_expr()
122}
123
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, cast, get_item, root, test_harness};

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    /// Replacing the child with an identical one succeeds and yields the same rendered cast.
    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // Render first: `expr` is not used again after the `with_children` call, in case that
        // call takes the expression by value.
        let rendered_before = expr.to_string();

        // Unwrap instead of discarding the result so a failure actually fails the test.
        let replaced = expr.with_children(vec![root()]).unwrap();
        assert_eq!(replaced.to_string(), rendered_before);
    }

    /// Casting an `i32` struct field to `i64` produces an array with the `i64` dtype.
    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }
}