vortex_array/builders/
extension.rs1use std::any::Any;
5
6use vortex_error::VortexResult;
7use vortex_error::vortex_ensure;
8use vortex_mask::Mask;
9
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::arrays::ExtensionArray;
13use crate::builders::ArrayBuilder;
14use crate::builders::DEFAULT_BUILDER_CAPACITY;
15use crate::builders::builder_with_capacity;
16use crate::canonical::Canonical;
17use crate::canonical::ToCanonical;
18use crate::dtype::DType;
19use crate::dtype::extension::ExtDTypeRef;
20use crate::scalar::ExtScalar;
21use crate::scalar::Scalar;
22
23pub struct ExtensionBuilder {
25 dtype: DType,
26 storage: Box<dyn ArrayBuilder>,
27}
28
29impl ExtensionBuilder {
30 pub fn new(ext_dtype: ExtDTypeRef) -> Self {
32 Self::with_capacity(ext_dtype, DEFAULT_BUILDER_CAPACITY)
33 }
34
35 pub fn with_capacity(ext_dtype: ExtDTypeRef, capacity: usize) -> Self {
37 Self {
38 storage: builder_with_capacity(ext_dtype.storage_dtype(), capacity),
39 dtype: DType::Extension(ext_dtype),
40 }
41 }
42
43 pub fn append_value(&mut self, value: ExtScalar) -> VortexResult<()> {
45 self.storage.append_scalar(&value.to_storage_scalar())
46 }
47
48 pub fn finish_into_extension(&mut self) -> ExtensionArray {
50 let storage = self.storage.finish();
51 ExtensionArray::new(self.ext_dtype(), storage)
52 }
53
54 fn ext_dtype(&self) -> ExtDTypeRef {
56 if let DType::Extension(ext_dtype) = &self.dtype {
57 ext_dtype.clone()
58 } else {
59 unreachable!()
60 }
61 }
62}
63
64impl ArrayBuilder for ExtensionBuilder {
65 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn as_any_mut(&mut self) -> &mut dyn Any {
70 self
71 }
72
73 fn dtype(&self) -> &DType {
74 &self.dtype
75 }
76
77 fn len(&self) -> usize {
78 self.storage.len()
79 }
80
81 fn append_zeros(&mut self, n: usize) {
82 self.storage.append_zeros(n)
83 }
84
85 unsafe fn append_nulls_unchecked(&mut self, n: usize) {
86 self.storage.append_nulls(n)
87 }
88
89 fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
90 vortex_ensure!(
91 scalar.dtype() == self.dtype(),
92 "ExtensionBuilder expected scalar with dtype {}, got {}",
93 self.dtype(),
94 scalar.dtype()
95 );
96
97 self.append_value(scalar.as_extension())
98 }
99
100 unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
101 let ext_array = array.to_extension();
102 self.storage.extend_from_array(ext_array.storage())
103 }
104
105 fn reserve_exact(&mut self, capacity: usize) {
106 self.storage.reserve_exact(capacity)
107 }
108
109 unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
110 unsafe { self.storage.set_validity_unchecked(validity) };
111 }
112
113 fn finish(&mut self) -> ArrayRef {
114 self.finish_into_extension().into_array()
115 }
116
117 fn finish_into_canonical(&mut self) -> Canonical {
118 Canonical::Extension(self.finish_into_extension())
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::dtype::Nullability;
    use crate::extension::datetime::Date;
    use crate::extension::datetime::TimeUnit;
    use crate::scalar::Scalar;

    /// Appends two valid dates and one null through `append_scalar`, checks the finished
    /// extension array, then checks that a scalar of a non-extension dtype is rejected.
    #[test]
    fn test_append_scalar() {
        let ext_dtype = Date::new(TimeUnit::Days, Nullability::Nullable).erased();

        let mut builder = ExtensionBuilder::new(ext_dtype.clone());

        // Two non-null day counts, each wrapped as a `Date` extension scalar.
        for days in [42i32, 84] {
            let storage = Scalar::from(Some(days));
            let ext_scalar = Scalar::extension::<Date>(TimeUnit::Days, storage);
            builder.append_scalar(&ext_scalar).unwrap();
        }

        // A null whose storage dtype is nullable i32.
        let null_storage = Scalar::null(DType::Primitive(
            crate::dtype::PType::I32,
            Nullability::Nullable,
        ));
        let null_scalar = Scalar::extension::<Date>(TimeUnit::Days, null_storage);
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_extension();
        let expected_storage =
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), None]).into_array();
        let expected = ExtensionArray::new(ext_dtype.clone(), expected_storage);

        assert_arrays_eq!(&array, &expected);
        assert_eq!(array.len(), 3);

        // A plain boolean scalar does not carry the extension dtype and must be refused.
        let mut rejecting_builder = ExtensionBuilder::new(ext_dtype);
        let wrong_scalar = Scalar::from(true);
        assert!(rejecting_builder.append_scalar(&wrong_scalar).is_err());
    }
}