vortex_layout/layouts/flat/
writer.rs1use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::Array;
7use vortex_array::ArrayContext;
8use vortex_array::expr::stats::Precision;
9use vortex_array::expr::stats::Stat;
10use vortex_array::expr::stats::StatsProvider;
11use vortex_array::normalize::NormalizeOptions;
12use vortex_array::normalize::Operation;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar::ScalarTruncation;
15use vortex_array::scalar::lower_bound;
16use vortex_array::scalar::upper_bound;
17use vortex_array::serde::SerializeOptions;
18use vortex_array::session::ArrayRegistry;
19use vortex_array::stats::StatsSetRef;
20use vortex_buffer::BufferString;
21use vortex_buffer::ByteBuffer;
22use vortex_dtype::DType;
23use vortex_error::VortexExpect;
24use vortex_error::VortexResult;
25use vortex_error::vortex_bail;
26use vortex_io::runtime::Handle;
27
28use crate::IntoLayout;
29use crate::LayoutRef;
30use crate::LayoutStrategy;
31use crate::layouts::flat::FlatLayout;
32use crate::layouts::flat::flat_layout_inline_array_node;
33use crate::segments::SegmentSinkRef;
34use crate::sequence::SendableSequentialStream;
35use crate::sequence::SequencePointer;
36
37#[derive(Clone)]
38pub struct FlatLayoutStrategy {
39 pub include_padding: bool,
41 pub max_variable_length_statistics_size: usize,
43 pub allowed_encodings: Option<ArrayRegistry>,
46}
47
48impl Default for FlatLayoutStrategy {
49 fn default() -> Self {
50 Self {
51 include_padding: true,
52 max_variable_length_statistics_size: 64,
53 allowed_encodings: None,
54 }
55 }
56}
57
58impl FlatLayoutStrategy {
59 pub fn with_include_padding(mut self, include_padding: bool) -> Self {
61 self.include_padding = include_padding;
62 self
63 }
64
65 pub fn with_max_variable_length_statistics_size(mut self, size: usize) -> Self {
67 self.max_variable_length_statistics_size = size;
68 self
69 }
70
71 pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self {
73 self.allowed_encodings = Some(allow_encodings);
74 self
75 }
76}
77
78fn truncate_scalar_stat<F: Fn(Scalar) -> Option<(Scalar, bool)>>(
79 statistics: StatsSetRef<'_>,
80 stat: Stat,
81 truncation: F,
82) {
83 if let Some(sv) = statistics.get(stat) {
84 if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) {
85 if truncated && let Some(v) = truncated_value.into_value() {
86 statistics.set(stat, Precision::Inexact(v));
87 }
88 } else {
89 statistics.clear(stat)
90 }
91 }
92}
93
94#[async_trait]
95impl LayoutStrategy for FlatLayoutStrategy {
96 async fn write_stream(
97 &self,
98 ctx: ArrayContext,
99 segment_sink: SegmentSinkRef,
100 mut stream: SendableSequentialStream,
101 _eof: SequencePointer,
102 _handle: Handle,
103 ) -> VortexResult<LayoutRef> {
104 let ctx = ctx.clone();
105 let Some(chunk) = stream.next().await else {
106 vortex_bail!("flat layout needs a single chunk");
107 };
108 let (sequence_id, chunk) = chunk?;
109
110 let row_count = chunk.len() as u64;
111
112 match chunk.dtype() {
113 DType::Utf8(n) => {
114 truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
115 lower_bound(
116 BufferString::from_scalar(v)
117 .vortex_expect("utf8 scalar must be a BufferString"),
118 self.max_variable_length_statistics_size,
119 *n,
120 )
121 });
122 truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
123 upper_bound(
124 BufferString::from_scalar(v)
125 .vortex_expect("utf8 scalar must be a BufferString"),
126 self.max_variable_length_statistics_size,
127 *n,
128 )
129 });
130 }
131 DType::Binary(n) => {
132 truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
133 lower_bound(
134 ByteBuffer::from_scalar(v)
135 .vortex_expect("binary scalar must be a ByteBuffer"),
136 self.max_variable_length_statistics_size,
137 *n,
138 )
139 });
140 truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
141 upper_bound(
142 ByteBuffer::from_scalar(v)
143 .vortex_expect("binary scalar must be a ByteBuffer"),
144 self.max_variable_length_statistics_size,
145 *n,
146 )
147 });
148 }
149 _ => {}
150 }
151
152 let chunk = if let Some(allowed) = &self.allowed_encodings {
153 chunk.normalize(&mut NormalizeOptions {
154 allowed,
155 operation: Operation::Error,
156 })?
157 } else {
158 chunk
159 };
160
161 let buffers = chunk.serialize(
162 &ctx,
163 &SerializeOptions {
164 offset: 0,
165 include_padding: self.include_padding,
166 },
167 )?;
168 assert!(buffers.len() >= 2);
170 let array_node =
171 flat_layout_inline_array_node().then(|| buffers[buffers.len() - 2].clone());
172 let segment_id = segment_sink.write(sequence_id, buffers).await?;
173
174 let None = stream.next().await else {
175 vortex_bail!("flat layout received stream with more than a single chunk");
176 };
177 Ok(FlatLayout::new_with_metadata(
178 row_count,
179 stream.dtype().clone(),
180 segment_id,
181 ctx.clone(),
182 array_node,
183 )
184 .into_layout())
185 }
186
187 fn buffered_bytes(&self) -> u64 {
188 0
190 }
191}
192
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Array;
    use vortex_array::ArrayContext;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::MaskFuture;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::DictArray;
    use vortex_array::arrays::DictVTable;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::PrimitiveVTable;
    use vortex_array::arrays::StructArray;
    use vortex_array::builders::ArrayBuilder;
    use vortex_array::builders::VarBinViewBuilder;
    use vortex_array::expr::root;
    use vortex_array::expr::stats::Precision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::session::ArrayRegistry;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBufferMut;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_io::runtime::single::block_on;
    use vortex_mask::AllOr;
    use vortex_mask::Mask;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    /// Writes a sorted primitive array through the flat layout, reads it back, and asserts
    /// that the result's `IsSorted` statistic is exactly `true`.
    // NOTE(review): this test is marked `should_panic`, presumably because `IsSorted` is not
    // carried through the flat layout round trip — confirm, and pin the panic message with
    // `#[should_panic(expected = ...)]` so an unrelated panic cannot make it pass.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    /// Long `Utf8` min/max statistics are truncated to the default size limit and
    /// stored as inexact bounds.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_expect("stats computation should succeed for test array")
                    .into_iter(),
            );

            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    /// A nullable struct array survives a write/read round trip with its validity
    /// and field values intact.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            let mut validity_builder = BitBufferMut::with_capacity(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.freeze();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let ctx = ArrayContext::empty();

            let (layout, segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                let layout = FlatLayoutStrategy::default()
                    .write_stream(
                        ctx,
                        segments.clone(),
                        array.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await
                    .unwrap();

                (layout, segments)
            };

            let result: ArrayRef = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.validity_mask().unwrap().bit_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            assert_eq!(
                result
                    .to_struct()
                    .unmasked_field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                result
                    .to_struct()
                    .unmasked_field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }

    /// Writing a filter-encoded array fails when only the primitive encoding is allowed.
    #[test]
    fn flat_invalid_array_fails() -> VortexResult<()> {
        block_on(|handle| async {
            let prim: PrimitiveArray = (0..10).collect();
            let filter = prim.filter(Mask::from_indices(10, vec![2, 3]))?;

            let ctx = ArrayContext::empty();

            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Only the primitive encoding is allowed, so the filter encoding is rejected.
                let allowed = ArrayRegistry::default();
                allowed.register(PrimitiveVTable::ID, PrimitiveVTable);
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        segments.clone(),
                        filter.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await;

                (layout, segments)
            };

            let err = layout.expect_err("expected error");
            assert!(
                err.to_string()
                    .contains("normalize forbids encoding (vortex.filter)"),
                "unexpected error: {err}"
            );

            Ok(())
        })
    }

    /// Writing a dictionary-encoded array succeeds when both its encodings are allowed.
    #[test]
    fn flat_valid_array_writes() -> VortexResult<()> {
        block_on(|handle| async {
            let codes: PrimitiveArray = (0u32..10).collect();
            let values: PrimitiveArray = (0..10).collect();
            let dict = DictArray::new(codes.into_array(), values.into_array());

            let ctx = ArrayContext::empty();

            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Allow both encodings that make up the dictionary array.
                let allowed = ArrayRegistry::default();
                allowed.register(PrimitiveVTable::ID, PrimitiveVTable);
                allowed.register(DictVTable::ID, DictVTable);
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        segments.clone(),
                        dict.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await;

                (layout, segments)
            };

            // Propagate any write error so a failing run reports its cause.
            layout?;

            Ok(())
        })
    }
}