vortex_expr/exprs/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::cast as compute_cast;
5use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
6use vortex_dtype::DType;
7use vortex_error::{VortexResult, vortex_bail, vortex_err};
8use vortex_proto::expr as pb;
9
10use crate::display::{DisplayAs, DisplayFormat};
11use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
12
13vtable!(Cast);
14
15#[allow(clippy::derived_hash_with_manual_eq)]
16#[derive(Debug, Clone, Hash, Eq)]
17pub struct CastExpr {
18    target: DType,
19    child: ExprRef,
20}
21
22impl PartialEq for CastExpr {
23    fn eq(&self, other: &Self) -> bool {
24        self.target == other.target && self.child.eq(&other.child)
25    }
26}
27
28pub struct CastExprEncoding;
29
30impl VTable for CastVTable {
31    type Expr = CastExpr;
32    type Encoding = CastExprEncoding;
33    type Metadata = ProstMetadata<pb::CastOpts>;
34
35    fn id(_encoding: &Self::Encoding) -> ExprId {
36        ExprId::new_ref("cast")
37    }
38
39    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
40        ExprEncodingRef::new_ref(CastExprEncoding.as_ref())
41    }
42
43    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
44        Some(ProstMetadata(pb::CastOpts {
45            target: Some((&expr.target).into()),
46        }))
47    }
48
49    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
50        vec![&expr.child]
51    }
52
53    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
54        Ok(CastExpr {
55            target: expr.target.clone(),
56            child: children[0].clone(),
57        })
58    }
59
60    fn build(
61        _encoding: &Self::Encoding,
62        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
63        children: Vec<ExprRef>,
64    ) -> VortexResult<Self::Expr> {
65        if children.len() != 1 {
66            vortex_bail!(
67                "Cast expression must have exactly 1 child, got {}",
68                children.len()
69            );
70        }
71        let target: DType = metadata
72            .target
73            .as_ref()
74            .ok_or_else(|| vortex_err!("missing target dtype in CastOpts"))?
75            .try_into()?;
76        Ok(CastExpr {
77            target,
78            child: children[0].clone(),
79        })
80    }
81
82    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83        let array = expr.child.evaluate(scope)?;
84        compute_cast(&array, &expr.target).map_err(|e| {
85            e.with_context(format!(
86                "Failed to cast array of dtype {} to {}",
87                array.dtype(),
88                expr.target
89            ))
90        })
91    }
92
93    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
94        Ok(expr.target.clone())
95    }
96}
97
98impl CastExpr {
99    pub fn new(child: ExprRef, target: DType) -> Self {
100        Self { target, child }
101    }
102
103    pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
104        Self::new(child, target).into_expr()
105    }
106}
107
108impl DisplayAs for CastExpr {
109    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
110        match df {
111            DisplayFormat::Compact => {
112                write!(f, "cast({}, {})", self.child, self.target)
113            }
114            DisplayFormat::Tree => {
115                write!(f, "Cast(target: {})", self.target)
116            }
117        }
118    }
119}
120
121impl AnalysisExpr for CastExpr {}
122
123/// Creates an expression that casts values to a target data type.
124///
125/// Converts the input expression's values to the specified target type.
126///
127/// ```rust
128/// # use vortex_dtype::{DType, Nullability, PType};
129/// # use vortex_expr::{cast, root};
130/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
131/// ```
132pub fn cast(child: ExprRef, target: DType) -> ExprRef {
133    CastExpr::new(child, target).into_expr()
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, cast, get_item, root, test_harness};

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // Replacing the single child must succeed and preserve the target dtype.
        // Previously the result was discarded, so this test could never fail.
        let replaced = expr.with_children(vec![root()]).unwrap();
        assert_eq!(replaced.to_string(), "cast($, bool?)");
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_display() {
        let expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(expr.to_string(), "cast($.value, i64)");

        let expr2 = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(expr2.to_string(), "cast($, bool?)");
    }
}