vortex_layout/layouts/flat/
writer.rs1use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::ArrayContext;
7use vortex_array::DynArray;
8use vortex_array::dtype::DType;
9use vortex_array::expr::stats::Precision;
10use vortex_array::expr::stats::Stat;
11use vortex_array::expr::stats::StatsProvider;
12use vortex_array::normalize::NormalizeOptions;
13use vortex_array::normalize::Operation;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar::ScalarTruncation;
16use vortex_array::scalar::lower_bound;
17use vortex_array::scalar::upper_bound;
18use vortex_array::serde::SerializeOptions;
19use vortex_array::session::ArrayRegistry;
20use vortex_array::stats::StatsSetRef;
21use vortex_buffer::BufferString;
22use vortex_buffer::ByteBuffer;
23use vortex_error::VortexExpect;
24use vortex_error::VortexResult;
25use vortex_error::vortex_bail;
26use vortex_io::runtime::Handle;
27use vortex_session::registry::ReadContext;
28
29use crate::IntoLayout;
30use crate::LayoutRef;
31use crate::LayoutStrategy;
32use crate::layouts::flat::FlatLayout;
33use crate::layouts::flat::flat_layout_inline_array_node;
34use crate::segments::SegmentSinkRef;
35use crate::sequence::SendableSequentialStream;
36use crate::sequence::SequencePointer;
37
38#[derive(Clone)]
39pub struct FlatLayoutStrategy {
40 pub include_padding: bool,
42 pub max_variable_length_statistics_size: usize,
44 pub allowed_encodings: Option<ArrayRegistry>,
47}
48
49impl Default for FlatLayoutStrategy {
50 fn default() -> Self {
51 Self {
52 include_padding: true,
53 max_variable_length_statistics_size: 64,
54 allowed_encodings: None,
55 }
56 }
57}
58
59impl FlatLayoutStrategy {
60 pub fn with_include_padding(mut self, include_padding: bool) -> Self {
62 self.include_padding = include_padding;
63 self
64 }
65
66 pub fn with_max_variable_length_statistics_size(mut self, size: usize) -> Self {
68 self.max_variable_length_statistics_size = size;
69 self
70 }
71
72 pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self {
74 self.allowed_encodings = Some(allow_encodings);
75 self
76 }
77}
78
79fn truncate_scalar_stat<F: Fn(Scalar) -> Option<(Scalar, bool)>>(
80 statistics: StatsSetRef<'_>,
81 stat: Stat,
82 truncation: F,
83) {
84 if let Some(sv) = statistics.get(stat) {
85 if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) {
86 if truncated && let Some(v) = truncated_value.into_value() {
87 statistics.set(stat, Precision::Inexact(v));
88 }
89 } else {
90 statistics.clear(stat)
91 }
92 }
93}
94
95#[async_trait]
96impl LayoutStrategy for FlatLayoutStrategy {
97 async fn write_stream(
98 &self,
99 ctx: ArrayContext,
100 segment_sink: SegmentSinkRef,
101 mut stream: SendableSequentialStream,
102 _eof: SequencePointer,
103 _handle: Handle,
104 ) -> VortexResult<LayoutRef> {
105 let ctx = ctx.clone();
106 let Some(chunk) = stream.next().await else {
107 vortex_bail!("flat layout needs a single chunk");
108 };
109 let (sequence_id, chunk) = chunk?;
110
111 let row_count = chunk.len() as u64;
112
113 match chunk.dtype() {
114 DType::Utf8(n) => {
115 truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
116 lower_bound(
117 BufferString::from_scalar(v)
118 .vortex_expect("utf8 scalar must be a BufferString"),
119 self.max_variable_length_statistics_size,
120 *n,
121 )
122 });
123 truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
124 upper_bound(
125 BufferString::from_scalar(v)
126 .vortex_expect("utf8 scalar must be a BufferString"),
127 self.max_variable_length_statistics_size,
128 *n,
129 )
130 });
131 }
132 DType::Binary(n) => {
133 truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
134 lower_bound(
135 ByteBuffer::from_scalar(v)
136 .vortex_expect("binary scalar must be a ByteBuffer"),
137 self.max_variable_length_statistics_size,
138 *n,
139 )
140 });
141 truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
142 upper_bound(
143 ByteBuffer::from_scalar(v)
144 .vortex_expect("binary scalar must be a ByteBuffer"),
145 self.max_variable_length_statistics_size,
146 *n,
147 )
148 });
149 }
150 _ => {}
151 }
152
153 let chunk = if let Some(allowed) = &self.allowed_encodings {
154 chunk.normalize(&mut NormalizeOptions {
155 allowed,
156 operation: Operation::Error,
157 })?
158 } else {
159 chunk
160 };
161
162 let buffers = chunk.serialize(
163 &ctx,
164 &SerializeOptions {
165 offset: 0,
166 include_padding: self.include_padding,
167 },
168 )?;
169 assert!(buffers.len() >= 2);
171 let array_node =
172 flat_layout_inline_array_node().then(|| buffers[buffers.len() - 2].clone());
173 let segment_id = segment_sink.write(sequence_id, buffers).await?;
174
175 let None = stream.next().await else {
176 vortex_bail!("flat layout received stream with more than a single chunk");
177 };
178 Ok(FlatLayout::new_with_metadata(
179 row_count,
180 stream.dtype().clone(),
181 segment_id,
182 ReadContext::new(ctx.to_ids()),
183 array_node,
184 )
185 .into_layout())
186 }
187
188 fn buffered_bytes(&self) -> u64 {
189 0
191 }
192}
193
/// Round-trip tests for `FlatLayoutStrategy`: write a single chunk, read it back with the
/// layout reader, and check statistics, data, and encoding restrictions.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::ArrayRef;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::MaskFuture;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::Dict;
    use vortex_array::arrays::DictArray;
    use vortex_array::arrays::Primitive;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::builders::ArrayBuilder;
    use vortex_array::builders::VarBinViewBuilder;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::FieldName;
    use vortex_array::dtype::FieldNames;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::root;
    use vortex_array::expr::stats::Precision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::session::ArrayRegistry;
    use vortex_array::validity::Validity;
    use vortex_array::vtable::DynVTableRef;
    use vortex_array::vtable::VTable;
    use vortex_buffer::BitBufferMut;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_io::runtime::single::block_on;
    use vortex_mask::AllOr;
    use vortex_mask::Mask;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    // NOTE(review): this test is expected to panic. Presumably the `IsSorted == Exact(true)`
    // assertion below does not hold after the round trip — confirm. Because `#[should_panic]`
    // has no `expected = "..."`, any panic (e.g. from an `unwrap()`) also counts as a pass.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            // Read every row back through the layout reader using the identity expression.
            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    /// Utf8 min/max statistics longer than the default limit (64) are stored truncated and
    /// marked inexact.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            // Populate every statistic up front so the writer has exact min/max to truncate.
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_expect("stats computation should succeed for test array")
                    .into_iter(),
            );

            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            // Min: the first 64 characters of the smaller string.
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            // Max: the first 64 characters of the larger string, with the last one bumped
            // ('i' -> 'j') so the value is still an upper bound.
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    /// A nullable struct array survives the write/read round trip with validity and field values
    /// intact.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            let mut validity_builder = BitBufferMut::with_capacity(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.freeze();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let ctx = ArrayContext::empty();

            let (layout, segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                let layout = FlatLayoutStrategy::default()
                    .write_stream(
                        ctx,
                        segments.clone(),
                        array.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await
                    .unwrap();

                (layout, segments)
            };

            let result: ArrayRef = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.validity_mask().unwrap().bit_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            // `unmasked_field_by_name` ignores struct-level validity, so the null row's
            // values are still visible here.
            assert_eq!(
                result
                    .to_struct()
                    .unmasked_field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                result
                    .to_struct()
                    .unmasked_field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }

    /// Writing a filter-encoded array fails when only the primitive encoding is allowed.
    #[test]
    fn flat_invalid_array_fails() -> VortexResult<()> {
        block_on(|handle| async {
            let prim: PrimitiveArray = (0..10).collect();
            let filter = prim.filter(Mask::from_indices(10, vec![2, 3]))?;

            let ctx = ArrayContext::empty();

            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Allow-list containing only the primitive encoding.
                let allowed = ArrayRegistry::default();
                allowed.register(Primitive::ID, Arc::new(Primitive) as DynVTableRef);
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        segments.clone(),
                        filter.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await;

                (layout, segments)
            };

            let err = layout.expect_err("expected error");
            assert!(
                err.to_string()
                    .contains("normalize forbids encoding (vortex.filter)"),
                "unexpected error: {err}"
            );

            Ok(())
        })
    }

    /// A dictionary array over primitive codes and values writes successfully when both
    /// encodings are allowed.
    #[test]
    fn flat_valid_array_writes() -> VortexResult<()> {
        block_on(|handle| async {
            let codes: PrimitiveArray = (0u32..10).collect();
            let values: PrimitiveArray = (0..10).collect();
            let dict = DictArray::new(codes.into_array(), values.into_array());

            let ctx = ArrayContext::empty();

            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Allow-list containing the primitive and dict encodings.
                let allowed = ArrayRegistry::default();
                allowed.register(Primitive.id(), Arc::new(Primitive) as DynVTableRef);
                allowed.register(Dict.id(), Arc::new(Dict) as DynVTableRef);
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        segments.clone(),
                        dict.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await;

                (layout, segments)
            };

            assert!(layout.is_ok());

            Ok(())
        })
    }
}