Skip to main content

vortex_array/builders/
extension.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use vortex_error::VortexResult;
7use vortex_error::vortex_ensure;
8use vortex_mask::Mask;
9
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::arrays::ExtensionArray;
13use crate::builders::ArrayBuilder;
14use crate::builders::DEFAULT_BUILDER_CAPACITY;
15use crate::builders::builder_with_capacity;
16use crate::canonical::Canonical;
17use crate::canonical::ToCanonical;
18use crate::dtype::DType;
19use crate::dtype::extension::ExtDTypeRef;
20use crate::scalar::ExtScalar;
21use crate::scalar::Scalar;
22
23/// The builder for building a [`ExtensionArray`].
24pub struct ExtensionBuilder {
25    dtype: DType,
26    storage: Box<dyn ArrayBuilder>,
27}
28
29impl ExtensionBuilder {
30    /// Creates a new `ExtensionBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
31    pub fn new(ext_dtype: ExtDTypeRef) -> Self {
32        Self::with_capacity(ext_dtype, DEFAULT_BUILDER_CAPACITY)
33    }
34
35    /// Creates a new `ExtensionBuilder` with the given `capacity`.
36    pub fn with_capacity(ext_dtype: ExtDTypeRef, capacity: usize) -> Self {
37        Self {
38            storage: builder_with_capacity(ext_dtype.storage_dtype(), capacity),
39            dtype: DType::Extension(ext_dtype),
40        }
41    }
42
43    /// Appends an extension `value` to the builder.
44    pub fn append_value(&mut self, value: ExtScalar) -> VortexResult<()> {
45        self.storage.append_scalar(&value.to_storage_scalar())
46    }
47
48    /// Finishes the builder directly into a [`ExtensionArray`].
49    pub fn finish_into_extension(&mut self) -> ExtensionArray {
50        let storage = self.storage.finish();
51        ExtensionArray::new(self.ext_dtype(), storage)
52    }
53
54    /// The [`ExtDType`] of this builder.
55    fn ext_dtype(&self) -> ExtDTypeRef {
56        if let DType::Extension(ext_dtype) = &self.dtype {
57            ext_dtype.clone()
58        } else {
59            unreachable!()
60        }
61    }
62}
63
64impl ArrayBuilder for ExtensionBuilder {
65    fn as_any(&self) -> &dyn Any {
66        self
67    }
68
69    fn as_any_mut(&mut self) -> &mut dyn Any {
70        self
71    }
72
73    fn dtype(&self) -> &DType {
74        &self.dtype
75    }
76
77    fn len(&self) -> usize {
78        self.storage.len()
79    }
80
81    fn append_zeros(&mut self, n: usize) {
82        self.storage.append_zeros(n)
83    }
84
85    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
86        self.storage.append_nulls(n)
87    }
88
89    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
90        vortex_ensure!(
91            scalar.dtype() == self.dtype(),
92            "ExtensionBuilder expected scalar with dtype {}, got {}",
93            self.dtype(),
94            scalar.dtype()
95        );
96
97        self.append_value(scalar.as_extension())
98    }
99
100    unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
101        let ext_array = array.to_extension();
102        self.storage.extend_from_array(ext_array.storage())
103    }
104
105    fn reserve_exact(&mut self, capacity: usize) {
106        self.storage.reserve_exact(capacity)
107    }
108
109    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
110        unsafe { self.storage.set_validity_unchecked(validity) };
111    }
112
113    fn finish(&mut self) -> ArrayRef {
114        self.finish_into_extension().into_array()
115    }
116
117    fn finish_into_canonical(&mut self) -> Canonical {
118        Canonical::Extension(self.finish_into_extension())
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::Date;
    use crate::extension::datetime::TimeUnit;
    use crate::scalar::Scalar;

    #[test]
    fn test_append_scalar() {
        let ext_dtype = Date::new(TimeUnit::Days, Nullability::Nullable).erased();
        let mut builder = ExtensionBuilder::new(ext_dtype.clone());

        // Two valid storage values followed by a null, each wrapped as a Date scalar.
        let null_storage = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let storages = [
            Scalar::from(Some(42i32)),
            Scalar::from(Some(84i32)),
            null_storage,
        ];
        for storage in storages {
            let date = Scalar::extension::<Date>(TimeUnit::Days, storage);
            builder.append_scalar(&date).unwrap();
        }

        let built = builder.finish_into_extension();
        let expected = ExtensionArray::new(
            ext_dtype.clone(),
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), None]).into_array(),
        );

        assert_arrays_eq!(&built, &expected);
        assert_eq!(built.len(), 3);

        // A scalar whose dtype differs from the builder's extension dtype must be rejected.
        let mut fresh = ExtensionBuilder::new(ext_dtype);
        let mismatched = Scalar::from(true);
        assert!(fresh.append_scalar(&mismatched).is_err());
    }
}