vortex_array/arrays/masked/vtable/
mod.rs1mod array;
5mod canonical;
6mod operations;
7mod validity;
8
9use vortex_dtype::DType;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_vector::Vector;
15use vortex_vector::VectorOps;
16
17use crate::ArrayBufferVisitor;
18use crate::ArrayChildVisitor;
19use crate::ArrayRef;
20use crate::EmptyMetadata;
21use crate::VectorExecutor;
22use crate::arrays::masked::MaskedArray;
23use crate::buffer::BufferHandle;
24use crate::executor::ExecutionCtx;
25use crate::serde::ArrayChildren;
26use crate::validity::Validity;
27use crate::vtable;
28use crate::vtable::ArrayId;
29use crate::vtable::ArrayVTable;
30use crate::vtable::ArrayVTableExt;
31use crate::vtable::NotSupported;
32use crate::vtable::VTable;
33use crate::vtable::ValidityVTableFromValidityHelper;
34use crate::vtable::VisitorVTable;
35
// Invoke the crate's `vtable!` macro for the `Masked` encoding.
// NOTE(review): the exact items it generates are defined by that macro, not here.
vtable!(Masked);

/// Vtable marker type for [`MaskedArray`].
///
/// It implements [`VTable`] and [`VisitorVTable`] in this module and identifies the
/// encoding as `"vortex.masked"`.
#[derive(Debug)]
pub struct MaskedVTable;
40
41impl VisitorVTable<MaskedVTable> for MaskedVTable {
42 fn visit_buffers(_array: &MaskedArray, _visitor: &mut dyn ArrayBufferVisitor) {}
43
44 fn visit_children(array: &MaskedArray, visitor: &mut dyn ArrayChildVisitor) {
45 visitor.visit_child("child", &array.child);
46 visitor.visit_validity(&array.validity, array.child.len());
47 }
48}
49
50impl VTable for MaskedVTable {
51 type Array = MaskedArray;
52
53 type Metadata = EmptyMetadata;
54
55 type ArrayVTable = Self;
56 type CanonicalVTable = Self;
57 type OperationsVTable = Self;
58 type ValidityVTable = ValidityVTableFromValidityHelper;
59 type VisitorVTable = Self;
60 type ComputeVTable = NotSupported;
61 type EncodeVTable = NotSupported;
62
63 fn id(&self) -> ArrayId {
64 ArrayId::new_ref("vortex.masked")
65 }
66
67 fn encoding(_array: &Self::Array) -> ArrayVTable {
68 MaskedVTable.as_vtable()
69 }
70
71 fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
72 Ok(EmptyMetadata)
73 }
74
75 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
76 Ok(Some(vec![]))
77 }
78
79 fn deserialize(_buffer: &[u8]) -> VortexResult<Self::Metadata> {
80 Ok(EmptyMetadata)
81 }
82
83 fn build(
84 &self,
85 dtype: &DType,
86 len: usize,
87 _metadata: &Self::Metadata,
88 buffers: &[BufferHandle],
89 children: &dyn ArrayChildren,
90 ) -> VortexResult<MaskedArray> {
91 if !buffers.is_empty() {
92 vortex_bail!("Expected 0 buffer, got {}", buffers.len());
93 }
94
95 let child = children.get(0, &dtype.as_nonnullable(), len)?;
96
97 let validity = if children.len() == 1 {
98 Validity::from(dtype.nullability())
99 } else if children.len() == 2 {
100 let validity = children.get(1, &Validity::DTYPE, len)?;
101 Validity::Array(validity)
102 } else {
103 vortex_bail!(
104 "`MaskedArray::build` expects 1 or 2 children, got {}",
105 children.len()
106 );
107 };
108
109 MaskedArray::try_new(child, validity)
110 }
111
112 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
113 let mut child = array.child().execute(ctx)?;
114 let validity_mask = array.validity_mask();
115
116 child.mask_validity(&validity_mask);
117 Ok(child)
118 }
119
120 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
121 vortex_ensure!(
122 children.len() == 1 || children.len() == 2,
123 "MaskedArray expects 1 or 2 children, got {}",
124 children.len()
125 );
126
127 let mut iter = children.into_iter();
128 let child = iter
129 .next()
130 .vortex_expect("children length already validated");
131 let validity = if let Some(validity_array) = iter.next() {
132 Validity::Array(validity_array)
133 } else {
134 Validity::from(array.dtype.nullability())
135 };
136
137 let new_array = MaskedArray::try_new(child, validity)?;
138 *array = new_array;
139 Ok(())
140 }
141}
142
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;
    use crate::vtable::ArrayVTableExt;

    /// Serializes a `MaskedArray`, decodes it back, and checks two things:
    /// - the result still uses the masked encoding;
    /// - it renders the same values as the input.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let ctx = ArrayContext::empty().with(MaskedVTable.as_vtable());
        let (dtype, len) = (array.dtype().clone(), array.len());

        // Serialization produces a sequence of buffers. Stitch them into one contiguous
        // buffer so the whole message can be parsed back as `ArrayParts`.
        let bytes = array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap()
            .into_iter()
            .fold(ByteBufferMut::empty(), |mut acc, buf| {
                acc.extend_from_slice(buf.as_ref());
                acc
            })
            .freeze();

        let parts = ArrayParts::try_from(bytes).unwrap();
        let decoded = parts.decode(&ctx, &dtype, len).unwrap();

        assert!(decoded.is::<MaskedVTable>());
        assert_eq!(
            array.as_ref().display_values().to_string(),
            decoded.display_values().to_string()
        );
    }
}