1use itertools::Itertools as _;
5use vortex_array::arrays::ChunkedArray;
6use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray};
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_bail};
9
10use crate::display::{DisplayAs, DisplayFormat};
11use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
12
// Declares the `ConcatVTable` type that the `VTable` impl in this file is written
// against (presumably together with related glue — see the `vtable!` macro definition).
vtable!(Concat);
14
15#[allow(clippy::derived_hash_with_manual_eq)]
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub struct ConcatExpr {
38 values: Vec<ExprRef>,
39}
40
/// Encoding marker for [`ConcatExpr`].
///
/// `ConcatVTable` returns it from `encoding` and maps it to the `"concat"` expression id.
pub struct ConcatExprEncoding;
42
43impl VTable for ConcatVTable {
44 type Expr = ConcatExpr;
45 type Encoding = ConcatExprEncoding;
46 type Metadata = EmptyMetadata;
47
48 fn id(_encoding: &Self::Encoding) -> ExprId {
49 ExprId::new_ref("concat")
50 }
51
52 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
53 ExprEncodingRef::new_ref(ConcatExprEncoding.as_ref())
54 }
55
56 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
57 Some(EmptyMetadata)
58 }
59
60 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
61 expr.values.iter().collect()
62 }
63
64 fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
65 Ok(ConcatExpr { values: children })
66 }
67
68 fn build(
69 _encoding: &Self::Encoding,
70 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
71 children: Vec<ExprRef>,
72 ) -> VortexResult<Self::Expr> {
73 Ok(ConcatExpr { values: children })
74 }
75
76 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
77 if expr.values.is_empty() {
78 vortex_bail!("Concat expression must have at least one child");
79 }
80
81 let value_arrays = expr
82 .values
83 .iter()
84 .map(|value_expr| value_expr.unchecked_evaluate(scope))
85 .process_results(|it| it.collect::<Vec<_>>())?;
86
87 let dtype = value_arrays[0].dtype().clone();
89
90 for array in &value_arrays[1..] {
92 if array.dtype() != &dtype {
93 vortex_bail!(
94 "All arrays in concat must have the same dtype, expected {:?} but got {:?}",
95 dtype,
96 array.dtype()
97 );
98 }
99 }
100
101 Ok(ChunkedArray::try_new(value_arrays, dtype)?.into_array())
102 }
103
104 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
105 if expr.values.is_empty() {
106 vortex_bail!("Concat expression must have at least one child");
107 }
108
109 let dtype = expr.values[0].return_dtype(scope)?;
111
112 for value_expr in &expr.values[1..] {
114 let child_dtype = value_expr.return_dtype(scope)?;
115 if child_dtype != dtype {
116 vortex_bail!(
117 "All expressions in concat must return the same dtype, expected {:?} but got {:?}",
118 dtype,
119 child_dtype
120 );
121 }
122 }
123
124 Ok(dtype)
125 }
126}
127
128impl ConcatExpr {
129 pub fn new(values: Vec<ExprRef>) -> Self {
130 ConcatExpr { values }
131 }
132
133 pub fn new_expr(values: Vec<ExprRef>) -> ExprRef {
134 Self::new(values).into_expr()
135 }
136
137 pub fn values(&self) -> &[ExprRef] {
138 &self.values
139 }
140}
141
142pub fn concat(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
152 let values = elements.into_iter().map(|value| value.into()).collect_vec();
153 ConcatExpr::new(values).into_expr()
154}
155
156impl DisplayAs for ConcatExpr {
157 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
158 match df {
159 DisplayFormat::Compact => {
160 write!(f, "concat({})", self.values.iter().format(", "))
161 }
162 DisplayFormat::Tree => {
163 write!(f, "Concat")
164 }
165 }
166 }
167}
168
// No concat-specific analysis: every `AnalysisExpr` method uses its default implementation.
impl AnalysisExpr for ConcatExpr {}
170
#[cfg(test)]
mod tests {
    use vortex_array::arrays::ChunkedVTable;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ConcatExpr, Scope, col, concat, lit, root};

    /// Struct array with two i32 columns: `a = [1, 2, 3]` and `b = [4, 5, 6]`.
    fn test_array() -> vortex_array::ArrayRef {
        vortex_array::arrays::StructArray::from_fields(&[
            ("a", buffer![1, 2, 3].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    #[test]
    pub fn test_concat_literals() {
        let expr = ConcatExpr::new(vec![
            lit(vortex_scalar::Scalar::from(1i32)),
            lit(vortex_scalar::Scalar::from(2i32)),
            lit(vortex_scalar::Scalar::from(3i32)),
        ]);

        let scope_array = buffer![0i32].into_array();
        let actual_array = expr.evaluate(&Scope::new(scope_array)).unwrap();

        let chunked = actual_array.as_::<ChunkedVTable>();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 3);

        let canonical = chunked.to_canonical().into_array();
        let primitive = canonical.to_primitive();
        assert_eq!(primitive.as_slice::<i32>(), &[1, 2, 3]);
    }

    #[test]
    pub fn test_concat_columns() {
        let expr = ConcatExpr::new(vec![col("a"), col("b"), col("a")]);

        let actual_array = expr.evaluate(&Scope::new(test_array())).unwrap();

        let chunked = actual_array.as_::<ChunkedVTable>();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let canonical = chunked.to_canonical().into_array();
        let primitive = canonical.to_primitive();
        assert_eq!(primitive.as_slice::<i32>(), &[1, 2, 3, 4, 5, 6, 1, 2, 3]);
    }

    #[test]
    pub fn test_concat_mixed() {
        let expr = ConcatExpr::new(vec![
            col("a"),
            lit(vortex_scalar::Scalar::from(99i32)),
            col("b"),
        ]);

        let actual_array = expr.evaluate(&Scope::new(test_array())).unwrap();

        let chunked = actual_array.as_::<ChunkedVTable>();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        let canonical = chunked.to_canonical().into_array();
        let primitive = canonical.to_primitive();
        assert_eq!(primitive.as_slice::<i32>(), &[1, 2, 3, 99, 99, 99, 4, 5, 6]);
    }

    #[test]
    pub fn test_concat_dtype_mismatch() {
        let expr = ConcatExpr::new(vec![
            lit(vortex_scalar::Scalar::from(1i32)),
            lit(vortex_scalar::Scalar::from(2i64)),
        ]);

        let result = expr.evaluate(&Scope::new(test_array()));
        assert!(result.is_err());
    }

    #[test]
    pub fn test_concat_empty() {
        // With no children there is neither a dtype to report nor anything to chunk,
        // so both evaluation and type inference must return an error.
        let expr = ConcatExpr::new(vec![]);

        assert!(expr.evaluate(&Scope::new(test_array())).is_err());
        assert!(
            expr.return_dtype(&DType::Primitive(PType::I32, Nullability::NonNullable))
                .is_err()
        );
    }

    #[test]
    pub fn test_return_dtype() {
        let expr = ConcatExpr::new(vec![
            lit(vortex_scalar::Scalar::from(1i32)),
            lit(vortex_scalar::Scalar::from(2i32)),
        ]);

        let dtype = expr
            .return_dtype(&DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();

        assert_eq!(
            dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    pub fn test_return_dtype_mismatch() {
        // Exercises the dtype check in `return_dtype`, independently of `evaluate`.
        let expr = ConcatExpr::new(vec![
            lit(vortex_scalar::Scalar::from(1i32)),
            lit(vortex_scalar::Scalar::from(2i64)),
        ]);

        let result = expr.return_dtype(&DType::Primitive(PType::I32, Nullability::NonNullable));
        assert!(result.is_err());
    }

    #[test]
    pub fn test_display() {
        let expr = concat([col("a"), col("b"), col("c")]);
        assert_eq!(expr.to_string(), "concat($.a, $.b, $.c)");
    }

    #[test]
    pub fn test_concat_with_root() {
        let expr = concat([root(), root()]);

        let test_array = buffer![1, 2, 3].into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();

        let chunked = actual_array.as_::<ChunkedVTable>();
        assert_eq!(chunked.nchunks(), 2);
        assert_eq!(chunked.len(), 6);

        let canonical = chunked.to_canonical().into_array();
        let primitive = canonical.to_primitive();
        assert_eq!(primitive.as_slice::<i32>(), &[1, 2, 3, 1, 2, 3]);
    }
}