vortex_array/arrays/variant/vtable/
mod.rs1mod kernel;
5mod operations;
6mod validity;
7
8use kernel::PARENT_KERNELS;
9use prost::Message;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_proto::dtype as pb;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17use vortex_utils::aliases::hash_set::HashSet;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::ExecutionResult;
22use crate::array::Array;
23use crate::array::ArrayId;
24use crate::array::ArrayParts;
25use crate::array::ArrayView;
26use crate::array::EmptyArrayData;
27use crate::array::VTable;
28use crate::arrays::variant::CORE_STORAGE_SLOT;
29use crate::arrays::variant::NUM_SLOTS;
30use crate::arrays::variant::SHREDDED_SLOT;
31use crate::arrays::variant::SLOT_NAMES;
32use crate::arrays::variant::compute::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::dtype::FieldName;
36use crate::dtype::FieldNames;
37use crate::dtype::Nullability;
38use crate::dtype::StructFields;
39use crate::scalar::Scalar;
40use crate::scalar::ScalarValue;
41use crate::serde::ArrayChildren;
42
43pub type VariantArray = Array<Variant>;
45
/// Unit marker type for the variant encoding.
///
/// It holds no state of its own; it exists to carry the [`VTable`] implementation
/// used by [`VariantArray`].
#[derive(Clone, Debug)]
pub struct Variant;
48
49#[derive(Clone, prost::Message)]
50struct VariantMetadataProto {
51 #[prost(message, optional, tag = "1")]
52 pub shredded_dtype: Option<pb::DType>,
53}
54
55impl VTable for Variant {
56 type TypedArrayData = EmptyArrayData;
57
58 type OperationsVTable = Self;
59
60 type ValidityVTable = Self;
61
62 fn id(&self) -> ArrayId {
63 static ID: CachedId = CachedId::new("vortex.variant");
64 *ID
65 }
66
67 fn validate(
68 &self,
69 _data: &Self::TypedArrayData,
70 dtype: &DType,
71 len: usize,
72 slots: &[Option<ArrayRef>],
73 ) -> VortexResult<()> {
74 vortex_ensure!(
75 slots.len() == NUM_SLOTS,
76 "VariantArray expects {NUM_SLOTS} slots, got {}",
77 slots.len()
78 );
79 vortex_ensure!(
80 slots[CORE_STORAGE_SLOT].is_some(),
81 "VariantArray core_storage slot must be present"
82 );
83 let core_storage = slots[CORE_STORAGE_SLOT]
84 .as_ref()
85 .vortex_expect("validated core_storage slot presence");
86 vortex_ensure!(
87 matches!(dtype, DType::Variant(_)),
88 "Expected Variant DType, got {dtype}"
89 );
90 vortex_ensure!(
91 core_storage.dtype() == dtype,
92 "VariantArray core_storage dtype {} does not match outer dtype {}",
93 core_storage.dtype(),
94 dtype
95 );
96 vortex_ensure!(
97 core_storage.len() == len,
98 "VariantArray core_storage length {} does not match outer length {}",
99 core_storage.len(),
100 len
101 );
102 if let Some(shredded) = slots[SHREDDED_SLOT].as_ref() {
103 vortex_ensure!(
104 shredded.len() == len,
105 "VariantArray shredded length {} does not match outer length {}",
106 shredded.len(),
107 len
108 );
109 }
110 Ok(())
111 }
112
113 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
114 0
115 }
116
117 fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
118 vortex_panic!("VariantArray buffer index {idx} out of bounds")
119 }
120
121 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
122 None
123 }
124
125 fn serialize(
126 array: ArrayView<'_, Self>,
127 _session: &VortexSession,
128 ) -> VortexResult<Option<Vec<u8>>> {
129 let shredded_dtype = array.slots()[SHREDDED_SLOT]
130 .as_ref()
131 .map(|shredded| shredded.dtype().try_into())
132 .transpose()?;
133 Ok(Some(
134 VariantMetadataProto { shredded_dtype }.encode_to_vec(),
135 ))
136 }
137
138 fn deserialize(
139 &self,
140 dtype: &DType,
141 len: usize,
142 metadata: &[u8],
143
144 buffers: &[BufferHandle],
145 children: &dyn ArrayChildren,
146 session: &VortexSession,
147 ) -> VortexResult<ArrayParts<Self>> {
148 vortex_ensure!(
149 buffers.is_empty(),
150 "VariantArray expects 0 buffers, got {}",
151 buffers.len()
152 );
153 let proto = VariantMetadataProto::decode(metadata)?;
154 let shredded_dtype = proto
155 .shredded_dtype
156 .as_ref()
157 .map(|dtype| DType::from_proto(dtype, session))
158 .transpose()?;
159 vortex_ensure!(matches!(dtype, DType::Variant(_)), "Expected Variant DType");
160 let expected_children = 1 + usize::from(shredded_dtype.is_some());
161 vortex_ensure!(
162 children.len() == expected_children,
163 "Expected {} children, got {}",
164 expected_children,
165 children.len(),
166 );
167 let core_storage = children.get(0, dtype, len)?;
168 let shredded = shredded_dtype
169 .map(|dtype| children.get(1, &dtype, len))
170 .transpose()?;
171 Ok(
172 ArrayParts::new(self.clone(), dtype.clone(), len, EmptyArrayData)
173 .with_slots(vec![Some(core_storage), shredded].into()),
174 )
175 }
176
177 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
178 match SLOT_NAMES.get(idx) {
179 Some(name) => (*name).to_string(),
180 None => vortex_panic!("VariantArray slot_name index {idx} out of bounds"),
181 }
182 }
183
184 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
185 Ok(ExecutionResult::done(array))
186 }
187
188 fn reduce_parent(
189 array: ArrayView<'_, Self>,
190 parent: &ArrayRef,
191 child_idx: usize,
192 ) -> VortexResult<Option<ArrayRef>> {
193 RULES.evaluate(array, parent, child_idx)
194 }
195
196 fn execute_parent(
197 array: ArrayView<'_, Self>,
198 parent: &ArrayRef,
199 child_idx: usize,
200 ctx: &mut ExecutionCtx,
201 ) -> VortexResult<Option<ArrayRef>> {
202 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
203 }
204}
205
206fn merge_typed_scalar_as_variant(
207 typed_scalar: Scalar,
208 fallback_scalar: Option<Scalar>,
209 dtype: &DType,
210) -> VortexResult<Scalar> {
211 let scalar = if typed_scalar.is_null() {
212 fallback_scalar.unwrap_or_else(|| Scalar::null(dtype.clone()))
213 } else if matches!(
214 typed_scalar.dtype(),
215 DType::List(..) | DType::FixedSizeList(..)
216 ) {
217 Scalar::variant(typed_list_as_variant_payload(typed_scalar)?)
218 } else if typed_scalar.dtype().is_struct() {
219 merge_typed_object_as_variant(typed_scalar, fallback_scalar)?
220 } else if typed_scalar.dtype().is_variant() {
221 typed_scalar
222 } else {
223 Scalar::variant(typed_scalar)
224 };
225
226 if scalar.dtype() == dtype {
227 Ok(scalar)
228 } else {
229 scalar.cast(dtype)
230 }
231}
232
233fn typed_list_as_variant_payload(typed_scalar: Scalar) -> VortexResult<Scalar> {
234 let list = typed_scalar.as_list();
235 let elements = list
236 .elements()
237 .unwrap_or_default()
238 .into_iter()
239 .map(|element| {
240 if element.dtype().is_variant() {
241 element
242 } else {
243 Scalar::variant(element)
244 }
245 })
246 .collect();
247 Ok(Scalar::list(
248 DType::Variant(Nullability::NonNullable),
249 elements,
250 Nullability::NonNullable,
251 ))
252}
253
254fn merge_typed_object_as_variant(
255 typed_scalar: Scalar,
256 fallback_scalar: Option<Scalar>,
257) -> VortexResult<Scalar> {
258 let fallback_inner = fallback_scalar
259 .as_ref()
260 .and_then(|scalar| scalar.as_variant().value())
261 .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null());
262 let Some(fallback_inner) = fallback_inner else {
263 return Ok(Scalar::variant(typed_scalar));
264 };
265
266 merge_struct_payload(&typed_scalar, Some(fallback_inner)).map(Scalar::variant)
267}
268
269fn merge_struct_payload(typed: &Scalar, raw: Option<&Scalar>) -> VortexResult<Scalar> {
270 let typed_struct = typed.as_struct();
271 let raw_struct = raw
272 .filter(|scalar| scalar.dtype().is_struct() && !scalar.is_null())
273 .map(Scalar::as_struct);
274 let mut present_typed_fields = HashSet::new();
275 let mut names = Vec::new();
276 let mut values = Vec::new();
277
278 for name in typed_struct.names().iter() {
279 let Some(typed_field) = typed_struct.field(name.as_ref()) else {
280 continue;
281 };
282 if typed_field.is_null() {
283 continue;
284 }
285
286 let raw_field = raw_struct.and_then(|raw_struct| raw_struct.field(name.as_ref()));
287 let raw_payload = raw_field.as_ref().and_then(|scalar| {
288 if scalar.dtype().is_variant() {
289 scalar.as_variant().value()
290 } else {
291 Some(scalar)
292 }
293 });
294 let field = if typed_field.dtype().is_struct()
295 && raw_payload.is_some_and(|raw| raw.dtype().is_struct() && !raw.is_null())
296 {
297 Scalar::variant(merge_struct_payload(&typed_field, raw_payload)?)
298 } else if typed_field.dtype().is_variant() {
299 typed_field.cast(&DType::Variant(Nullability::NonNullable))?
300 } else {
301 Scalar::variant(typed_field)
302 };
303
304 present_typed_fields.insert(name.as_ref().to_string());
305 names.push(FieldName::from(name.as_ref()));
306 values.push(field.into_value());
307 }
308
309 if let Some(raw_struct) = raw_struct {
310 for name in raw_struct.names().iter() {
311 if present_typed_fields.contains(name.as_ref()) {
312 continue;
313 }
314 let Some(raw_field) = raw_struct.field(name.as_ref()) else {
315 continue;
316 };
317 if raw_field.is_null() {
318 continue;
319 }
320 let raw_field = if raw_field.dtype().is_variant() {
321 raw_field.cast(&DType::Variant(Nullability::NonNullable))?
322 } else {
323 Scalar::variant(raw_field)
324 };
325 names.push(FieldName::from(name.as_ref()));
326 values.push(raw_field.into_value());
327 }
328 }
329
330 let fields = StructFields::new(
331 FieldNames::from(names),
332 vec![DType::Variant(Nullability::NonNullable); values.len()],
333 );
334 Scalar::try_new(
335 DType::Struct(fields, Nullability::NonNullable),
336 Some(ScalarValue::Tuple(values)),
337 )
338}
339
#[cfg(test)]
mod tests {
    // No unit tests are defined in this module.
}