vortex_array/expr/exprs/
cast.rs1use std::fmt::Formatter;
5use std::ops::Deref;
6
7use prost::Message;
8use vortex_dtype::{DType, FieldPath};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11
12use crate::ArrayRef;
13use crate::compute::cast as compute_cast;
14use crate::expr::expression::Expression;
15use crate::expr::{ChildName, ExprId, ExpressionView, StatsCatalog, VTable, VTableExt};
16
/// Expression vtable (id `vortex.cast`) that casts the value of its single child expression
/// to a target [`DType`], which is stored as the expression's instance data.
pub struct Cast;
19
20impl VTable for Cast {
21 type Instance = DType;
22
23 fn id(&self) -> ExprId {
24 ExprId::from("vortex.cast")
25 }
26
27 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
28 Ok(Some(
29 pb::CastOpts {
30 target: Some(instance.into()),
31 }
32 .encode_to_vec(),
33 ))
34 }
35
36 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
37 Ok(Some(
38 pb::CastOpts::decode(metadata)?
39 .target
40 .as_ref()
41 .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
42 .try_into()?,
43 ))
44 }
45
46 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
47 if expr.children().len() != 1 {
48 vortex_bail!(
49 "Cast expression requires exactly 1 child, got {}",
50 expr.children().len()
51 );
52 }
53 Ok(())
54 }
55
56 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
57 match child_idx {
58 0 => ChildName::from("input"),
59 _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
60 }
61 }
62
63 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
64 write!(f, "cast(")?;
65 expr.children()[0].fmt_sql(f)?;
66 write!(f, " as {}", expr.data())?;
67 write!(f, ")")
68 }
69
70 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
71 write!(f, "{}", instance)
72 }
73
74 fn return_dtype(&self, expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
75 Ok(expr.data().clone())
76 }
77
78 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
79 let array = expr.children()[0].evaluate(scope)?;
80 compute_cast(&array, expr.data()).map_err(|e| {
81 e.with_context(format!(
82 "Failed to cast array of dtype {} to {}",
83 array.dtype(),
84 expr.deref()
85 ))
86 })
87 }
88
89 fn stat_max(
90 &self,
91 expr: &ExpressionView<Self>,
92 catalog: &mut dyn StatsCatalog,
93 ) -> Option<Expression> {
94 expr.children()[0].stat_max(catalog)
95 }
96
97 fn stat_min(
98 &self,
99 expr: &ExpressionView<Self>,
100 catalog: &mut dyn StatsCatalog,
101 ) -> Option<Expression> {
102 expr.children()[0].stat_min(catalog)
103 }
104
105 fn stat_nan_count(
106 &self,
107 expr: &ExpressionView<Self>,
108 catalog: &mut dyn StatsCatalog,
109 ) -> Option<Expression> {
110 expr.children()[0].stat_nan_count(catalog)
111 }
112
113 fn stat_field_path(&self, expr: &ExpressionView<Self>) -> Option<FieldPath> {
114 expr.children()[0].stat_field_path()
115 }
116}
117
118pub fn cast(child: Expression, target: DType) -> Expression {
128 Cast.try_new_expr(target, [child])
129 .vortex_expect("Failed to create Cast expression")
130}
131
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::{Expression, test_harness};

    /// The return dtype of a cast is its target, regardless of the scope.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let actual = cast(root(), target.clone()).return_dtype(&scope).unwrap();

        assert_eq!(actual, target);
    }

    /// Replacing the single child with another single child must succeed.
    #[test]
    fn replace_children() {
        cast(root(), DType::Bool(Nullability::Nullable))
            .with_children(vec![root()])
            .vortex_unwrap();
    }

    /// Casting an i32 field to i64 yields an array with the i64 dtype.
    #[test]
    fn evaluate() {
        let field_a = buffer![0i32, 1, 2].into_array();
        let field_b = buffer![4i64, 5, 6].into_array();
        let scope = StructArray::from_fields(&[("a", field_a), ("b", field_b)])
            .unwrap()
            .into_array();

        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let expr: Expression = cast(get_item("a", root()), target.clone());
        let evaluated = expr.evaluate(&scope).unwrap();

        assert_eq!(evaluated.dtype(), &target);
    }

    /// Display output uses the `cast(<child> as <dtype>)` form.
    #[test]
    fn test_display() {
        let cases = [
            (
                cast(
                    get_item("value", root()),
                    DType::Primitive(PType::I64, Nullability::NonNullable),
                ),
                "cast($.value as i64)",
            ),
            (
                cast(root(), DType::Bool(Nullability::Nullable)),
                "cast($ as bool?)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}