vortex_expr/
cast.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::ArrayRef;
6use vortex_array::compute::cast as compute_cast;
7use vortex_dtype::DType;
8use vortex_error::{VortexExpect, VortexResult};
9
10use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
11
12#[derive(Debug, Eq, Hash)]
13#[allow(clippy::derived_hash_with_manual_eq)]
14pub struct Cast {
15    target: DType,
16    child: ExprRef,
17}
18
19impl Cast {
20    pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
21        Arc::new(Self { target, child })
22    }
23}
24
25impl PartialEq for Cast {
26    fn eq(&self, other: &Self) -> bool {
27        self.target.eq(&other.target) && self.child.eq(&other.child)
28    }
29}
30
31impl Display for Cast {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "cast({}, {})", self.child, self.target)
34    }
35}
36
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_dtype::DType;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::cast::Cast;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde handle for [`Cast`] expressions, identified by the id `"cast"`.
    pub(crate) struct CastSerde;

    impl Id for CastSerde {
        fn id(&self) -> &'static str {
            "cast"
        }
    }

    impl ExprDeserialize for CastSerde {
        /// Rebuilds a [`Cast`] from its protobuf `kind` and its already-deserialized `children`.
        ///
        /// # Errors
        /// Fails if `kind` is not a cast, if the target dtype is missing or cannot be
        /// converted, or if `children` does not contain exactly one expression.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::Cast(kind::Cast { target }) = kind else {
                vortex_bail!("wrong kind {:?}, want cast", kind)
            };
            let target: DType = target
                .as_ref()
                .ok_or_else(|| vortex_err!("empty target dtype"))?
                .try_into()?;

            // Serialized input may be malformed: report a wrong child count as an error
            // instead of panicking on `children[0]` or silently ignoring extra children.
            let n_children = children.len();
            let Ok([child]) = <[ExprRef; 1]>::try_from(children) else {
                vortex_bail!("cast expects exactly one child, got {}", n_children)
            };

            Ok(Cast::new_expr(child, target))
        }
    }

    impl ExprSerializable for Cast {
        fn id(&self) -> &'static str {
            CastSerde.id()
        }

        /// Encodes only the target dtype. The child expression is not part of the kind
        /// payload.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::Cast(kind::Cast {
                target: Some((&self.target).into()),
            }))
        }
    }
}
81
82impl AnalysisExpr for Cast {}
83
84impl VortexExpr for Cast {
85    fn as_any(&self) -> &dyn Any {
86        self
87    }
88
89    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
90        let array = self.child.evaluate(scope)?;
91        compute_cast(&array, &self.target)
92    }
93
94    fn children(&self) -> Vec<&ExprRef> {
95        vec![&self.child]
96    }
97
98    fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
99        Self::new_expr(
100            children
101                .pop()
102                .vortex_expect("Cast::replacing_children should have one child"),
103            self.target.clone(),
104        )
105    }
106
107    fn return_dtype(&self, _scope_dtype: &ScopeDType) -> VortexResult<DType> {
108        Ok(self.target.clone())
109    }
110}
111
112pub fn cast(child: ExprRef, target: DType) -> ExprRef {
113    Cast::new_expr(child, target)
114}
115
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, ScopeDType, cast, get_item, root, test_harness};

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&ScopeDType::new(dtype))
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let target = DType::Bool(Nullability::Nullable);
        let expr = cast(root(), target.clone());
        let replaced = expr.replacing_children(vec![root()]);

        // The rebuilt expression must still be a single-child cast to the original target.
        assert_eq!(replaced.children().len(), 1);
        assert_eq!(
            replaced
                .return_dtype(&ScopeDType::new(test_harness::struct_dtype()))
                .unwrap(),
            target
        );
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }
}
162}