vortex_array/builders/
extension.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use vortex_dtype::{DType, ExtDType};
8use vortex_error::{VortexResult, vortex_ensure};
9use vortex_mask::Mask;
10use vortex_scalar::{ExtScalar, Scalar};
11
12use crate::arrays::ExtensionArray;
13use crate::builders::{ArrayBuilder, DEFAULT_BUILDER_CAPACITY, builder_with_capacity};
14use crate::canonical::{Canonical, ToCanonical};
15use crate::{Array, ArrayRef, IntoArray};
16
17/// The builder for building a [`ExtensionArray`].
18pub struct ExtensionBuilder {
19    dtype: DType,
20    storage: Box<dyn ArrayBuilder>,
21}
22
23impl ExtensionBuilder {
24    /// Creates a new `ExtensionBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
25    pub fn new(ext_dtype: Arc<ExtDType>) -> Self {
26        Self::with_capacity(ext_dtype, DEFAULT_BUILDER_CAPACITY)
27    }
28
29    /// Creates a new `ExtensionBuilder` with the given `capacity`.
30    pub fn with_capacity(ext_dtype: Arc<ExtDType>, capacity: usize) -> Self {
31        Self {
32            storage: builder_with_capacity(ext_dtype.storage_dtype(), capacity),
33            dtype: DType::Extension(ext_dtype),
34        }
35    }
36
37    /// Appends an extension `value` to the builder.
38    pub fn append_value(&mut self, value: ExtScalar) -> VortexResult<()> {
39        self.storage.append_scalar(&value.storage())
40    }
41
42    /// Finishes the builder directly into a [`ExtensionArray`].
43    pub fn finish_into_extension(&mut self) -> ExtensionArray {
44        let storage = self.storage.finish();
45        ExtensionArray::new(self.ext_dtype(), storage)
46    }
47
48    /// The [`ExtDType`] of this builder.
49    fn ext_dtype(&self) -> Arc<ExtDType> {
50        if let DType::Extension(ext_dtype) = &self.dtype {
51            ext_dtype.clone()
52        } else {
53            unreachable!()
54        }
55    }
56}
57
58impl ArrayBuilder for ExtensionBuilder {
59    fn as_any(&self) -> &dyn Any {
60        self
61    }
62
63    fn as_any_mut(&mut self) -> &mut dyn Any {
64        self
65    }
66
67    fn dtype(&self) -> &DType {
68        &self.dtype
69    }
70
71    fn len(&self) -> usize {
72        self.storage.len()
73    }
74
75    fn append_zeros(&mut self, n: usize) {
76        self.storage.append_zeros(n)
77    }
78
79    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
80        self.storage.append_nulls(n)
81    }
82
83    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
84        vortex_ensure!(
85            scalar.dtype() == self.dtype(),
86            "ExtensionBuilder expected scalar with dtype {:?}, got {:?}",
87            self.dtype(),
88            scalar.dtype()
89        );
90
91        let ext_scalar = ExtScalar::try_from(scalar)?;
92        self.append_value(ext_scalar)
93    }
94
95    unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) {
96        let ext_array = array.to_extension();
97        self.storage.extend_from_array(ext_array.storage())
98    }
99
100    fn ensure_capacity(&mut self, capacity: usize) {
101        self.storage.ensure_capacity(capacity)
102    }
103
104    fn set_validity(&mut self, validity: Mask) {
105        self.storage.set_validity(validity);
106    }
107
108    fn finish(&mut self) -> ArrayRef {
109        self.finish_into_extension().into_array()
110    }
111
112    fn finish_into_canonical(&mut self) -> Canonical {
113        Canonical::Extension(self.finish_into_extension())
114    }
115}
116
#[cfg(test)]
mod tests {
    use vortex_dtype::{ExtDType, ExtID, Nullability, PType};
    use vortex_scalar::Scalar;

    use super::*;
    use crate::builders::ArrayBuilder;

    /// The nullable `i32` dtype used as the extension's storage dtype in these tests.
    fn storage_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::Nullable)
    }

    #[test]
    fn test_append_scalar() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(storage_dtype()),
            None,
        ));

        let mut builder = ExtensionBuilder::new(ext_dtype.clone());

        // Two valid extension values, appended in order.
        for value in [42i32, 84] {
            let ext_scalar = Scalar::extension(ext_dtype.clone(), Scalar::from(value));
            builder.append_scalar(&ext_scalar).unwrap();
        }

        // A null extension value (null storage scalar).
        let null_scalar = Scalar::extension(ext_dtype.clone(), Scalar::null(storage_dtype()));
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_extension();
        assert_eq!(array.len(), 3);

        // Read every element back with `scalar_at`; the last one has null storage.
        let expected = [Some(42), Some(84), None];
        for (index, want) in expected.into_iter().enumerate() {
            let scalar = array.scalar_at(index);
            let ext = scalar.as_extension();
            assert_eq!(ext.storage().as_primitive().typed_value::<i32>(), want);
        }

        // A scalar with a non-extension dtype must be rejected.
        let mut builder = ExtensionBuilder::new(ext_dtype);
        let wrong_scalar = Scalar::from(true);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}