vortex_array/arrow/
record_batch.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::RecordBatch;
5use arrow_array::cast::AsArray;
6use arrow_schema::{DataType, Schema};
7use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_ensure};
8
9use crate::arrays::StructArray;
10use crate::arrow::compute::{to_arrow, to_arrow_preferred};
11use crate::{Array, Canonical};
12
13impl TryFrom<&dyn Array> for RecordBatch {
14    type Error = VortexError;
15
16    fn try_from(value: &dyn Array) -> VortexResult<Self> {
17        let Canonical::Struct(struct_array) = value.to_canonical() else {
18            vortex_bail!("RecordBatch can only be constructed from ")
19        };
20
21        vortex_ensure!(
22            struct_array.all_valid(),
23            "RecordBatch can only be constructed from StructArray with no nulls"
24        );
25
26        let array_ref = to_arrow_preferred(struct_array.as_ref())?;
27        Ok(RecordBatch::from(array_ref.as_struct()))
28    }
29}
30
31impl StructArray {
32    pub fn into_record_batch_with_schema(
33        self,
34        schema: impl AsRef<Schema>,
35    ) -> VortexResult<RecordBatch> {
36        let data_type = DataType::Struct(schema.as_ref().fields.clone());
37        let array_ref = to_arrow(self.as_ref(), &data_type)?;
38        Ok(RecordBatch::from(array_ref.as_struct()))
39    }
40}
41
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, FieldRef, Schema};
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::StructArray;
    use crate::builders::{ArrayBuilder, ListBuilder};

    /// Checks that `into_record_batch_with_schema` exports a list column with the
    /// Arrow type requested by the schema.
    #[test]
    fn test_into_rb_with_schema() {
        let mut list_builder = ListBuilder::<u32>::new(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );

        // Three rows: a populated list, a null, and the builder's zero value.
        let populated = Scalar::list(
            list_builder.element_dtype().clone(),
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::Nullable,
        );
        list_builder.append_scalar(&populated).unwrap();
        list_builder.append_null();
        list_builder.append_zero();

        let lists = list_builder.finish();
        let struct_array = StructArray::from_fields(&[("xs", lists)]).unwrap();

        // Explicitly request a conversion to LargeListView type instead of the preferred type.
        let requested_type =
            DataType::LargeListView(FieldRef::new(Field::new_list_field(DataType::Int32, false)));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "xs",
            requested_type.clone(),
            true,
        )]));

        let batch = struct_array
            .into_record_batch_with_schema(schema)
            .unwrap();

        let column = batch.column(0);
        assert_eq!(column.data_type(), &requested_type);
    }
}