vortex_array/builders/
extension.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use vortex_dtype::DType;
8use vortex_dtype::ExtDType;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_mask::Mask;
12use vortex_scalar::ExtScalar;
13use vortex_scalar::Scalar;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::arrays::ExtensionArray;
19use crate::builders::ArrayBuilder;
20use crate::builders::DEFAULT_BUILDER_CAPACITY;
21use crate::builders::builder_with_capacity;
22use crate::canonical::Canonical;
23use crate::canonical::ToCanonical;
24
25/// The builder for building a [`ExtensionArray`].
26pub struct ExtensionBuilder {
27    dtype: DType,
28    storage: Box<dyn ArrayBuilder>,
29}
30
31impl ExtensionBuilder {
32    /// Creates a new `ExtensionBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
33    pub fn new(ext_dtype: Arc<ExtDType>) -> Self {
34        Self::with_capacity(ext_dtype, DEFAULT_BUILDER_CAPACITY)
35    }
36
37    /// Creates a new `ExtensionBuilder` with the given `capacity`.
38    pub fn with_capacity(ext_dtype: Arc<ExtDType>, capacity: usize) -> Self {
39        Self {
40            storage: builder_with_capacity(ext_dtype.storage_dtype(), capacity),
41            dtype: DType::Extension(ext_dtype),
42        }
43    }
44
45    /// Appends an extension `value` to the builder.
46    pub fn append_value(&mut self, value: ExtScalar) -> VortexResult<()> {
47        self.storage.append_scalar(&value.storage())
48    }
49
50    /// Finishes the builder directly into a [`ExtensionArray`].
51    pub fn finish_into_extension(&mut self) -> ExtensionArray {
52        let storage = self.storage.finish();
53        ExtensionArray::new(self.ext_dtype(), storage)
54    }
55
56    /// The [`ExtDType`] of this builder.
57    fn ext_dtype(&self) -> Arc<ExtDType> {
58        if let DType::Extension(ext_dtype) = &self.dtype {
59            ext_dtype.clone()
60        } else {
61            unreachable!()
62        }
63    }
64}
65
66impl ArrayBuilder for ExtensionBuilder {
67    fn as_any(&self) -> &dyn Any {
68        self
69    }
70
71    fn as_any_mut(&mut self) -> &mut dyn Any {
72        self
73    }
74
75    fn dtype(&self) -> &DType {
76        &self.dtype
77    }
78
79    fn len(&self) -> usize {
80        self.storage.len()
81    }
82
83    fn append_zeros(&mut self, n: usize) {
84        self.storage.append_zeros(n)
85    }
86
87    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
88        self.storage.append_nulls(n)
89    }
90
91    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
92        vortex_ensure!(
93            scalar.dtype() == self.dtype(),
94            "ExtensionBuilder expected scalar with dtype {:?}, got {:?}",
95            self.dtype(),
96            scalar.dtype()
97        );
98
99        let ext_scalar = ExtScalar::try_from(scalar)?;
100        self.append_value(ext_scalar)
101    }
102
103    unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) {
104        let ext_array = array.to_extension();
105        self.storage.extend_from_array(ext_array.storage())
106    }
107
108    fn reserve_exact(&mut self, capacity: usize) {
109        self.storage.reserve_exact(capacity)
110    }
111
112    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
113        unsafe { self.storage.set_validity_unchecked(validity) };
114    }
115
116    fn finish(&mut self) -> ArrayRef {
117        self.finish_into_extension().into_array()
118    }
119
120    fn finish_into_canonical(&mut self) -> Canonical {
121        Canonical::Extension(self.finish_into_extension())
122    }
123}
124
#[cfg(test)]
mod tests {
    use vortex_dtype::ExtDType;
    use vortex_dtype::ExtID;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;

    /// Appending extension scalars (including a null) yields the matching extension array,
    /// and a scalar of a different dtype is rejected.
    #[test]
    fn test_append_scalar() {
        // Nullable i32 storage backs the test extension type.
        let storage_dtype =
            || DType::Primitive(vortex_dtype::PType::I32, Nullability::Nullable);
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(storage_dtype()),
            None,
        ));

        let mut builder = ExtensionBuilder::new(ext_dtype.clone());

        // Two valid extension values followed by a null one.
        let scalars = [
            Scalar::extension(ext_dtype.clone(), Scalar::from(42i32)),
            Scalar::extension(ext_dtype.clone(), Scalar::from(84i32)),
            Scalar::extension(ext_dtype.clone(), Scalar::null(storage_dtype())),
        ];
        for scalar in &scalars {
            builder.append_scalar(scalar).unwrap();
        }

        let array = builder.finish_into_extension();
        let expected_storage =
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), None]).into_array();
        let expected = ExtensionArray::new(ext_dtype.clone(), expected_storage);

        assert_arrays_eq!(&array, &expected);
        assert_eq!(array.len(), 3);

        // A scalar whose dtype is not the extension dtype must be rejected.
        let mut builder = ExtensionBuilder::new(ext_dtype);
        let wrong_scalar = Scalar::from(true);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}