vortex_expr/exprs/
concat.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use vortex_array::arrays::ChunkedArray;
6use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray};
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_bail};
9
10use crate::display::{DisplayAs, DisplayFormat};
11use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
12
13vtable!(Concat);
14
15/// Concatenate zero or more expressions into a single array.
16///
17/// All child expressions must evaluate to arrays of the same dtype.
18///
19/// # Examples
20///
21/// ```
22/// use vortex_array::{IntoArray, ToCanonical};
23/// use vortex_buffer::buffer;
24/// use vortex_expr::{ConcatExpr, Scope, lit};
25/// use vortex_scalar::Scalar;
26///
27/// let example = ConcatExpr::new(vec![
28///     lit(Scalar::from(100)),
29///     lit(Scalar::from(200)),
30///     lit(Scalar::from(300)),
31/// ]);
32/// let concatenated = example.evaluate(&Scope::empty(1)).unwrap();
33/// assert_eq!(concatenated.len(), 3);
34/// ```
35#[allow(clippy::derived_hash_with_manual_eq)]
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub struct ConcatExpr {
38    values: Vec<ExprRef>,
39}
40
/// Zero-sized encoding marker for [`ConcatExpr`]; it carries no state of its own.
pub struct ConcatExprEncoding;
42
43impl VTable for ConcatVTable {
44    type Expr = ConcatExpr;
45    type Encoding = ConcatExprEncoding;
46    type Metadata = EmptyMetadata;
47
48    fn id(_encoding: &Self::Encoding) -> ExprId {
49        ExprId::new_ref("concat")
50    }
51
52    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
53        ExprEncodingRef::new_ref(ConcatExprEncoding.as_ref())
54    }
55
56    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
57        Some(EmptyMetadata)
58    }
59
60    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
61        expr.values.iter().collect()
62    }
63
64    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
65        Ok(ConcatExpr { values: children })
66    }
67
68    fn build(
69        _encoding: &Self::Encoding,
70        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
71        children: Vec<ExprRef>,
72    ) -> VortexResult<Self::Expr> {
73        Ok(ConcatExpr { values: children })
74    }
75
76    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
77        if expr.values.is_empty() {
78            vortex_bail!("Concat expression must have at least one child");
79        }
80
81        let value_arrays = expr
82            .values
83            .iter()
84            .map(|value_expr| value_expr.unchecked_evaluate(scope))
85            .process_results(|it| it.collect::<Vec<_>>())?;
86
87        // Get the common dtype from the first array
88        let dtype = value_arrays[0].dtype().clone();
89
90        // Validate all arrays have the same dtype
91        for array in &value_arrays[1..] {
92            if array.dtype() != &dtype {
93                vortex_bail!(
94                    "All arrays in concat must have the same dtype, expected {:?} but got {:?}",
95                    dtype,
96                    array.dtype()
97                );
98            }
99        }
100
101        Ok(ChunkedArray::try_new(value_arrays, dtype)?.into_array())
102    }
103
104    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
105        if expr.values.is_empty() {
106            vortex_bail!("Concat expression must have at least one child");
107        }
108
109        // Return the dtype of the first child - all children must have the same dtype
110        let dtype = expr.values[0].return_dtype(scope)?;
111
112        // Validate all children have the same dtype
113        for value_expr in &expr.values[1..] {
114            let child_dtype = value_expr.return_dtype(scope)?;
115            if child_dtype != dtype {
116                vortex_bail!(
117                    "All expressions in concat must return the same dtype, expected {:?} but got {:?}",
118                    dtype,
119                    child_dtype
120                );
121            }
122        }
123
124        Ok(dtype)
125    }
126}
127
128impl ConcatExpr {
129    pub fn new(values: Vec<ExprRef>) -> Self {
130        ConcatExpr { values }
131    }
132
133    pub fn new_expr(values: Vec<ExprRef>) -> ExprRef {
134        Self::new(values).into_expr()
135    }
136
137    pub fn values(&self) -> &[ExprRef] {
138        &self.values
139    }
140}
141
142/// Creates an expression that concatenates multiple expressions into a single array.
143///
144/// All input expressions must evaluate to arrays of the same dtype.
145///
146/// ```rust
147/// # use vortex_expr::{concat, col, lit};
148/// # use vortex_scalar::Scalar;
149/// let expr = concat([col("chunk1"), col("chunk2"), lit(Scalar::from(42))]);
150/// ```
151pub fn concat(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
152    let values = elements.into_iter().map(|value| value.into()).collect_vec();
153    ConcatExpr::new(values).into_expr()
154}
155
156impl DisplayAs for ConcatExpr {
157    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
158        match df {
159            DisplayFormat::Compact => {
160                write!(f, "concat({})", self.values.iter().format(", "))
161            }
162            DisplayFormat::Tree => {
163                write!(f, "Concat")
164            }
165        }
166    }
167}
168
169impl AnalysisExpr for ConcatExpr {}
170
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ChunkedVTable, StructArray};
    use vortex_array::{ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::{ConcatExpr, Scope, col, concat, lit, root};

    /// Struct array with two i32 columns: `a = [1, 2, 3]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let fields = [
            ("a", buffer![1, 2, 3].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap().into_array()
    }

    /// Asserts that `array` is a chunked array with `nchunks` chunks whose canonical i32
    /// contents (and therefore total length) equal `expected`.
    fn assert_chunked_i32(array: &ArrayRef, nchunks: usize, expected: &[i32]) {
        let chunked = array.as_::<ChunkedVTable>();
        assert_eq!(chunked.nchunks(), nchunks);
        assert_eq!(chunked.len(), expected.len());

        let canonical = chunked.to_canonical().into_array();
        let primitive = canonical.to_primitive();
        assert_eq!(primitive.as_slice::<i32>(), expected);
    }

    #[test]
    fn test_concat_literals() {
        let expr = ConcatExpr::new(vec![
            lit(Scalar::from(1i32)),
            lit(Scalar::from(2i32)),
            lit(Scalar::from(3i32)),
        ]);

        // Each literal expands to scope.len() rows, so a scope of length 1 yields one row each.
        let scope = Scope::new(buffer![0i32].into_array());
        let result = expr.evaluate(&scope).unwrap();

        assert_chunked_i32(&result, 3, &[1, 2, 3]);
    }

    #[test]
    fn test_concat_columns() {
        let expr = ConcatExpr::new(vec![col("a"), col("b"), col("a")]);

        let result = expr.evaluate(&Scope::new(test_array())).unwrap();

        assert_chunked_i32(&result, 3, &[1, 2, 3, 4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_concat_mixed() {
        let expr = ConcatExpr::new(vec![col("a"), lit(Scalar::from(99i32)), col("b")]);

        let result = expr.evaluate(&Scope::new(test_array())).unwrap();

        // 3 rows (col a) + 3 rows (literal 99 expanded to scope.len()) + 3 rows (col b) = 9.
        assert_chunked_i32(&result, 3, &[1, 2, 3, 99, 99, 99, 4, 5, 6]);
    }

    #[test]
    fn test_concat_dtype_mismatch() {
        // i32 and i64 literals do not share a dtype, so evaluation must fail.
        let expr = ConcatExpr::new(vec![lit(Scalar::from(1i32)), lit(Scalar::from(2i64))]);

        let outcome = expr.evaluate(&Scope::new(test_array()));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_return_dtype() {
        let expr = ConcatExpr::new(vec![lit(Scalar::from(1i32)), lit(Scalar::from(2i32))]);

        let scope_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let dtype = expr.return_dtype(&scope_dtype).unwrap();

        assert_eq!(
            dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_display() {
        let expr = concat([col("a"), col("b"), col("c")]);
        assert_eq!(expr.to_string(), "concat($.a, $.b, $.c)");
    }

    #[test]
    fn test_concat_with_root() {
        let expr = concat([root(), root()]);

        let input = buffer![1, 2, 3].into_array();
        let result = expr.evaluate(&Scope::new(input)).unwrap();

        assert_chunked_i32(&result, 2, &[1, 2, 3, 1, 2, 3]);
    }
}