vortex_array/builders/
extension.rs1use std::any::Any;
5use std::sync::Arc;
6
7use vortex_dtype::{DType, ExtDType};
8use vortex_error::{VortexResult, vortex_ensure};
9use vortex_mask::Mask;
10use vortex_scalar::{ExtScalar, Scalar};
11
12use crate::arrays::ExtensionArray;
13use crate::builders::{ArrayBuilder, DEFAULT_BUILDER_CAPACITY, builder_with_capacity};
14use crate::canonical::{Canonical, ToCanonical};
15use crate::{Array, ArrayRef, IntoArray};
16
17pub struct ExtensionBuilder {
19 dtype: DType,
20 storage: Box<dyn ArrayBuilder>,
21}
22
23impl ExtensionBuilder {
24 pub fn new(ext_dtype: Arc<ExtDType>) -> Self {
26 Self::with_capacity(ext_dtype, DEFAULT_BUILDER_CAPACITY)
27 }
28
29 pub fn with_capacity(ext_dtype: Arc<ExtDType>, capacity: usize) -> Self {
31 Self {
32 storage: builder_with_capacity(ext_dtype.storage_dtype(), capacity),
33 dtype: DType::Extension(ext_dtype),
34 }
35 }
36
37 pub fn append_value(&mut self, value: ExtScalar) -> VortexResult<()> {
39 self.storage.append_scalar(&value.storage())
40 }
41
42 pub fn finish_into_extension(&mut self) -> ExtensionArray {
44 let storage = self.storage.finish();
45 ExtensionArray::new(self.ext_dtype(), storage)
46 }
47
48 fn ext_dtype(&self) -> Arc<ExtDType> {
50 if let DType::Extension(ext_dtype) = &self.dtype {
51 ext_dtype.clone()
52 } else {
53 unreachable!()
54 }
55 }
56}
57
58impl ArrayBuilder for ExtensionBuilder {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn as_any_mut(&mut self) -> &mut dyn Any {
64 self
65 }
66
67 fn dtype(&self) -> &DType {
68 &self.dtype
69 }
70
71 fn len(&self) -> usize {
72 self.storage.len()
73 }
74
75 fn append_zeros(&mut self, n: usize) {
76 self.storage.append_zeros(n)
77 }
78
79 unsafe fn append_nulls_unchecked(&mut self, n: usize) {
80 self.storage.append_nulls(n)
81 }
82
83 fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
84 vortex_ensure!(
85 scalar.dtype() == self.dtype(),
86 "ExtensionBuilder expected scalar with dtype {:?}, got {:?}",
87 self.dtype(),
88 scalar.dtype()
89 );
90
91 let ext_scalar = ExtScalar::try_from(scalar)?;
92 self.append_value(ext_scalar)
93 }
94
95 unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) {
96 let ext_array = array.to_extension();
97 self.storage.extend_from_array(ext_array.storage())
98 }
99
100 fn ensure_capacity(&mut self, capacity: usize) {
101 self.storage.ensure_capacity(capacity)
102 }
103
104 fn set_validity(&mut self, validity: Mask) {
105 self.storage.set_validity(validity);
106 }
107
108 fn finish(&mut self) -> ArrayRef {
109 self.finish_into_extension().into_array()
110 }
111
112 fn finish_into_canonical(&mut self) -> Canonical {
113 Canonical::Extension(self.finish_into_extension())
114 }
115}
116
#[cfg(test)]
mod tests {
    use vortex_dtype::{ExtDType, ExtID, Nullability};
    use vortex_scalar::Scalar;

    use super::*;
    use crate::builders::ArrayBuilder;

    /// Appends two values and a null through `append_scalar`, reads them back,
    /// then checks that a scalar of the wrong dtype is rejected.
    #[test]
    fn test_append_scalar() {
        // Storage dtype shared by the extension dtype and the null scalar.
        let nullable_i32 =
            || DType::Primitive(vortex_dtype::PType::I32, Nullability::Nullable);

        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(nullable_i32()),
            None,
        ));

        let mut builder = ExtensionBuilder::new(ext_dtype.clone());

        // Two non-null values.
        for value in [42i32, 84i32] {
            let ext_scalar = Scalar::extension(ext_dtype.clone(), Scalar::from(value));
            builder.append_scalar(&ext_scalar).unwrap();
        }

        // One null value.
        let null_scalar = Scalar::extension(ext_dtype.clone(), Scalar::null(nullable_i32()));
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_extension();
        assert_eq!(array.len(), 3);

        // The storage values must come back in insertion order.
        let expected_values: [Option<i32>; 3] = [Some(42), Some(84), None];
        for (index, expected) in expected_values.into_iter().enumerate() {
            let scalar = array.scalar_at(index);
            let ext = scalar.as_extension();
            assert_eq!(ext.storage().as_primitive().typed_value::<i32>(), expected);
        }

        // A scalar whose dtype is not the builder's extension dtype is rejected.
        let mut fresh_builder = ExtensionBuilder::new(ext_dtype);
        let wrong_scalar = Scalar::from(true);
        let outcome = fresh_builder.append_scalar(&wrong_scalar);
        assert!(outcome.is_err());
    }
}