vortex_array/arrays/variant/vtable/
mod.rs1mod kernel;
5mod operations;
6mod validity;
7
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_proto::dtype as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16use vortex_utils::aliases::hash_set::HashSet;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::ExecutionResult;
21use crate::array::Array;
22use crate::array::ArrayId;
23use crate::array::ArrayParts;
24use crate::array::ArrayView;
25use crate::array::EmptyArrayData;
26use crate::array::VTable;
27use crate::array::with_empty_buffers;
28use crate::arrays::variant::CORE_STORAGE_SLOT;
29use crate::arrays::variant::NUM_SLOTS;
30use crate::arrays::variant::SHREDDED_SLOT;
31use crate::arrays::variant::SLOT_NAMES;
32use crate::arrays::variant::compute::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::dtype::FieldName;
36use crate::dtype::FieldNames;
37use crate::dtype::Nullability;
38use crate::dtype::StructFields;
39use crate::scalar::Scalar;
40use crate::scalar::ScalarValue;
41use crate::serde::ArrayChildren;
42
/// Convenience alias for an [`Array`] whose vtable is [`Variant`].
pub type VariantArray = Array<Variant>;
45
/// Performs session-level setup for the variant array by delegating to
/// `kernel::initialize` with the given session.
pub(crate) fn initialize(session: &VortexSession) {
    kernel::initialize(session);
}
49
/// Marker type carrying the [`VTable`] implementation for variant arrays
/// (array id `vortex.variant`).
#[derive(Clone, Debug)]
pub struct Variant;
52
/// Protobuf metadata written by `serialize` and read back by `deserialize`.
#[derive(Clone, prost::Message)]
struct VariantMetadataProto {
    /// DType of the shredded child; `None` when the shredded slot is absent.
    #[prost(message, optional, tag = "1")]
    pub shredded_dtype: Option<pb::DType>,
}
58
59impl VTable for Variant {
60 type TypedArrayData = EmptyArrayData;
61
62 type OperationsVTable = Self;
63
64 type ValidityVTable = Self;
65
66 fn id(&self) -> ArrayId {
67 static ID: CachedId = CachedId::new("vortex.variant");
68 *ID
69 }
70
71 fn validate(
72 &self,
73 _data: &Self::TypedArrayData,
74 dtype: &DType,
75 len: usize,
76 slots: &[Option<ArrayRef>],
77 ) -> VortexResult<()> {
78 vortex_ensure!(
79 slots.len() == NUM_SLOTS,
80 "VariantArray expects {NUM_SLOTS} slots, got {}",
81 slots.len()
82 );
83 vortex_ensure!(
84 slots[CORE_STORAGE_SLOT].is_some(),
85 "VariantArray core_storage slot must be present"
86 );
87 let core_storage = slots[CORE_STORAGE_SLOT]
88 .as_ref()
89 .vortex_expect("validated core_storage slot presence");
90 vortex_ensure!(
91 matches!(dtype, DType::Variant(_)),
92 "Expected Variant DType, got {dtype}"
93 );
94 vortex_ensure!(
95 core_storage.dtype() == dtype,
96 "VariantArray core_storage dtype {} does not match outer dtype {}",
97 core_storage.dtype(),
98 dtype
99 );
100 vortex_ensure!(
101 core_storage.len() == len,
102 "VariantArray core_storage length {} does not match outer length {}",
103 core_storage.len(),
104 len
105 );
106 if let Some(shredded) = slots[SHREDDED_SLOT].as_ref() {
107 vortex_ensure!(
108 shredded.len() == len,
109 "VariantArray shredded length {} does not match outer length {}",
110 shredded.len(),
111 len
112 );
113 }
114 Ok(())
115 }
116
117 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
118 0
119 }
120
121 fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
122 vortex_panic!("VariantArray buffer index {idx} out of bounds")
123 }
124
125 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
126 None
127 }
128
129 fn with_buffers(
130 &self,
131 array: ArrayView<'_, Self>,
132 buffers: &[BufferHandle],
133 ) -> VortexResult<ArrayParts<Self>> {
134 with_empty_buffers(self, array, buffers)
135 }
136
137 fn serialize(
138 array: ArrayView<'_, Self>,
139 _session: &VortexSession,
140 ) -> VortexResult<Option<Vec<u8>>> {
141 let shredded_dtype = array.slots()[SHREDDED_SLOT]
142 .as_ref()
143 .map(|shredded| shredded.dtype().try_into())
144 .transpose()?;
145 Ok(Some(
146 VariantMetadataProto { shredded_dtype }.encode_to_vec(),
147 ))
148 }
149
150 fn deserialize(
151 &self,
152 dtype: &DType,
153 len: usize,
154 metadata: &[u8],
155 buffers: &[BufferHandle],
156 children: &dyn ArrayChildren,
157 session: &VortexSession,
158 ) -> VortexResult<ArrayParts<Self>> {
159 vortex_ensure!(
160 buffers.is_empty(),
161 "VariantArray expects 0 buffers, got {}",
162 buffers.len()
163 );
164 let proto = VariantMetadataProto::decode(metadata)?;
165 let shredded_dtype = proto
166 .shredded_dtype
167 .as_ref()
168 .map(|dtype| DType::from_proto(dtype, session))
169 .transpose()?;
170 vortex_ensure!(matches!(dtype, DType::Variant(_)), "Expected Variant DType");
171 let expected_children = 1 + usize::from(shredded_dtype.is_some());
172 vortex_ensure!(
173 children.len() == expected_children,
174 "Expected {} children, got {}",
175 expected_children,
176 children.len(),
177 );
178 let core_storage = children.get(0, dtype, len)?;
179 let shredded = shredded_dtype
180 .map(|dtype| children.get(1, &dtype, len))
181 .transpose()?;
182 Ok(
183 ArrayParts::new(self.clone(), dtype.clone(), len, EmptyArrayData)
184 .with_slots(vec![Some(core_storage), shredded].into()),
185 )
186 }
187
188 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
189 match SLOT_NAMES.get(idx) {
190 Some(name) => (*name).to_string(),
191 None => vortex_panic!("VariantArray slot_name index {idx} out of bounds"),
192 }
193 }
194
195 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
196 Ok(ExecutionResult::done(array))
197 }
198
199 fn reduce_parent(
200 array: ArrayView<'_, Self>,
201 parent: &ArrayRef,
202 child_idx: usize,
203 ) -> VortexResult<Option<ArrayRef>> {
204 RULES.evaluate(array, parent, child_idx)
205 }
206}
207
208fn merge_typed_scalar_as_variant(
209 typed_scalar: Scalar,
210 fallback_scalar: Option<Scalar>,
211 dtype: &DType,
212) -> VortexResult<Scalar> {
213 let scalar = if typed_scalar.is_null() {
214 fallback_scalar.unwrap_or_else(|| Scalar::null(dtype.clone()))
215 } else if matches!(
216 typed_scalar.dtype(),
217 DType::List(..) | DType::FixedSizeList(..)
218 ) {
219 Scalar::variant(typed_list_as_variant_payload(typed_scalar)?)
220 } else if typed_scalar.dtype().is_struct() {
221 merge_typed_object_as_variant(typed_scalar, fallback_scalar)?
222 } else if typed_scalar.dtype().is_variant() {
223 typed_scalar
224 } else {
225 Scalar::variant(typed_scalar)
226 };
227
228 if scalar.dtype() == dtype {
229 Ok(scalar)
230 } else {
231 scalar.cast(dtype)
232 }
233}
234
235fn typed_list_as_variant_payload(typed_scalar: Scalar) -> VortexResult<Scalar> {
236 let list = typed_scalar.as_list();
237 let elements = list
238 .elements()
239 .unwrap_or_default()
240 .into_iter()
241 .map(|element| {
242 if element.dtype().is_variant() {
243 element
244 } else {
245 Scalar::variant(element)
246 }
247 })
248 .collect();
249 Ok(Scalar::list(
250 DType::Variant(Nullability::NonNullable),
251 elements,
252 Nullability::NonNullable,
253 ))
254}
255
256fn merge_typed_object_as_variant(
257 typed_scalar: Scalar,
258 fallback_scalar: Option<Scalar>,
259) -> VortexResult<Scalar> {
260 let fallback_inner = fallback_scalar
261 .as_ref()
262 .and_then(|scalar| scalar.as_variant().value())
263 .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null());
264 let Some(fallback_inner) = fallback_inner else {
265 return Ok(Scalar::variant(typed_scalar));
266 };
267
268 merge_struct_payload(&typed_scalar, Some(fallback_inner)).map(Scalar::variant)
269}
270
271fn merge_struct_payload(typed: &Scalar, raw: Option<&Scalar>) -> VortexResult<Scalar> {
272 let typed_struct = typed.as_struct();
273 let raw_struct = raw
274 .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null())
275 .map(Scalar::as_struct);
276 let mut present_typed_fields = HashSet::new();
277 let mut names = Vec::new();
278 let mut values = Vec::new();
279
280 for name in typed_struct.names().iter() {
281 let Some(typed_field) = typed_struct.field(name.as_ref()) else {
282 continue;
283 };
284 if typed_field.is_null() {
285 continue;
286 }
287
288 let raw_field = raw_struct.and_then(|raw_struct| raw_struct.field(name.as_ref()));
289 let raw_payload = raw_field.as_ref().and_then(|scalar| {
290 if scalar.dtype().is_variant() {
291 scalar.as_variant().value()
292 } else {
293 Some(scalar)
294 }
295 });
296 let field = if typed_field.dtype().is_struct()
297 && raw_payload.is_some_and(|raw| raw.dtype().is_struct() && !raw.is_null())
298 {
299 Scalar::variant(merge_struct_payload(&typed_field, raw_payload)?)
300 } else if typed_field.dtype().is_variant() {
301 typed_field.cast(&DType::Variant(Nullability::NonNullable))?
302 } else {
303 Scalar::variant(typed_field)
304 };
305
306 present_typed_fields.insert(name.as_ref().to_string());
307 names.push(FieldName::from(name.as_ref()));
308 values.push(field.into_value());
309 }
310
311 if let Some(raw_struct) = raw_struct {
312 for name in raw_struct.names().iter() {
313 if present_typed_fields.contains(name.as_ref()) {
314 continue;
315 }
316 let Some(raw_field) = raw_struct.field(name.as_ref()) else {
317 continue;
318 };
319 if raw_field.is_null() {
320 continue;
321 }
322 let raw_field = if raw_field.dtype().is_variant() {
323 raw_field.cast(&DType::Variant(Nullability::NonNullable))?
324 } else {
325 Scalar::variant(raw_field)
326 };
327 names.push(FieldName::from(name.as_ref()));
328 values.push(raw_field.into_value());
329 }
330 }
331
332 let fields = StructFields::new(
333 FieldNames::from(names),
334 vec![DType::Variant(Nullability::NonNullable); values.len()],
335 );
336 Scalar::try_new(
337 DType::Struct(fields, Nullability::NonNullable),
338 Some(ScalarValue::Tuple(values)),
339 )
340}
341
// NOTE(review): this test module is empty in the visible source — confirm
// whether its contents were elided or the tests live elsewhere.
#[cfg(test)]
mod tests {}