vortex_layout/layouts/flat/
writer.rs1use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::ArrayContext;
7use vortex_array::ArrayId;
8use vortex_array::dtype::DType;
9use vortex_array::expr::stats::Precision;
10use vortex_array::expr::stats::Stat;
11use vortex_array::expr::stats::StatsProvider;
12use vortex_array::normalize::NormalizeOptions;
13use vortex_array::normalize::Operation;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar::ScalarTruncation;
16use vortex_array::scalar::lower_bound;
17use vortex_array::scalar::upper_bound;
18use vortex_array::serde::SerializeOptions;
19use vortex_array::stats::StatsSetRef;
20use vortex_buffer::BufferString;
21use vortex_buffer::ByteBuffer;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_bail;
25use vortex_session::VortexSession;
26use vortex_session::registry::ReadContext;
27use vortex_utils::aliases::hash_set::HashSet;
28
29use crate::IntoLayout;
30use crate::LayoutRef;
31use crate::LayoutStrategy;
32use crate::layouts::flat::FlatLayout;
33use crate::layouts::flat::flat_layout_inline_array_node;
34use crate::segments::SegmentSinkRef;
35use crate::sequence::SendableSequentialStream;
36use crate::sequence::SequencePointer;
37
38#[derive(Clone)]
39pub struct FlatLayoutStrategy {
40 pub include_padding: bool,
42 pub max_variable_length_statistics_size: usize,
44 pub allowed_encodings: Option<HashSet<ArrayId>>,
47}
48
49impl Default for FlatLayoutStrategy {
50 fn default() -> Self {
51 Self {
52 include_padding: true,
53 max_variable_length_statistics_size: 64,
54 allowed_encodings: None,
55 }
56 }
57}
58
59impl FlatLayoutStrategy {
60 pub fn with_include_padding(mut self, include_padding: bool) -> Self {
62 self.include_padding = include_padding;
63 self
64 }
65
66 pub fn with_max_variable_length_statistics_size(mut self, size: usize) -> Self {
68 self.max_variable_length_statistics_size = size;
69 self
70 }
71
72 pub fn with_allow_encodings(mut self, allow_encodings: HashSet<ArrayId>) -> Self {
74 self.allowed_encodings = Some(allow_encodings);
75 self
76 }
77}
78
79fn truncate_scalar_stat<F: Fn(Scalar) -> Option<(Scalar, bool)>>(
80 statistics: StatsSetRef<'_>,
81 stat: Stat,
82 truncation: F,
83) {
84 if let Some(sv) = statistics.get(stat) {
85 if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) {
86 if truncated && let Some(v) = truncated_value.into_value() {
87 statistics.set(stat, Precision::Inexact(v));
88 }
89 } else {
90 statistics.clear(stat)
91 }
92 }
93}
94
95#[async_trait]
96impl LayoutStrategy for FlatLayoutStrategy {
97 async fn write_stream(
98 &self,
99 ctx: ArrayContext,
100 segment_sink: SegmentSinkRef,
101 mut stream: SendableSequentialStream,
102 _eof: SequencePointer,
103 session: &VortexSession,
104 ) -> VortexResult<LayoutRef> {
105 let ctx = ctx.clone();
106 let Some(chunk) = stream.next().await else {
107 vortex_bail!("flat layout needs a single chunk");
108 };
109 let (sequence_id, chunk) = chunk?;
110
111 let row_count = chunk.len() as u64;
112
113 match chunk.dtype() {
114 DType::Utf8(n) => {
115 truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
116 lower_bound(
117 BufferString::from_scalar(v)
118 .vortex_expect("utf8 scalar must be a BufferString"),
119 self.max_variable_length_statistics_size,
120 *n,
121 )
122 });
123 truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
124 upper_bound(
125 BufferString::from_scalar(v)
126 .vortex_expect("utf8 scalar must be a BufferString"),
127 self.max_variable_length_statistics_size,
128 *n,
129 )
130 });
131 }
132 DType::Binary(n) => {
133 truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
134 lower_bound(
135 ByteBuffer::from_scalar(v)
136 .vortex_expect("binary scalar must be a ByteBuffer"),
137 self.max_variable_length_statistics_size,
138 *n,
139 )
140 });
141 truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
142 upper_bound(
143 ByteBuffer::from_scalar(v)
144 .vortex_expect("binary scalar must be a ByteBuffer"),
145 self.max_variable_length_statistics_size,
146 *n,
147 )
148 });
149 }
150 _ => {}
151 }
152
153 let chunk = if let Some(allowed) = &self.allowed_encodings {
154 chunk.normalize(&mut NormalizeOptions {
155 allowed,
156 operation: Operation::Error,
157 })?
158 } else {
159 chunk
160 };
161
162 let buffers = chunk.serialize(
163 &ctx,
164 session,
165 &SerializeOptions {
166 offset: 0,
167 include_padding: self.include_padding,
168 },
169 )?;
170 assert!(buffers.len() >= 2);
172 let array_node =
173 flat_layout_inline_array_node().then(|| buffers[buffers.len() - 2].clone());
174 let segment_id = segment_sink.write(sequence_id, buffers).await?;
175
176 let None = stream.next().await else {
177 vortex_bail!("flat layout received stream with more than a single chunk");
178 };
179 Ok(FlatLayout::new_with_metadata(
180 row_count,
181 stream.dtype().clone(),
182 segment_id,
183 ReadContext::new(ctx.to_ids()),
184 array_node,
185 )
186 .into_layout())
187 }
188
189 fn buffered_bytes(&self) -> u64 {
190 0
192 }
193}
194
/// Round-trip tests for `FlatLayoutStrategy`: statistics handling, struct validity, and the
/// encoding allow-list.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::MaskFuture;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::Dict;
    use vortex_array::arrays::DictArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::arrays::struct_::StructArrayExt;
    use vortex_array::builders::ArrayBuilder;
    use vortex_array::builders::VarBinViewBuilder;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::FieldName;
    use vortex_array::dtype::FieldNames;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::root;
    use vortex_array::expr::stats::Precision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::validity::Validity;
    use vortex_array::vtable::VTable;
    use vortex_buffer::BitBufferMut;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_io::runtime::single::block_on;
    use vortex_io::session::RuntimeSessionExt;
    use vortex_mask::AllOr;
    use vortex_mask::Mask;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    // Writes a sorted primitive array and checks that `IsSorted` is exactly `true` on the
    // read-back array.
    // NOTE(review): this test is marked `#[should_panic]` — presumably the read-back array does
    // not carry an exact `IsSorted` statistic; confirm the intended reason.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    array.into_array().to_array_stream().sequenced(ptr),
                    eof,
                    &session,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    // Writes two long UTF-8 values with all statistics precomputed and checks that Min/Max come
    // back truncated to the default 64-byte limit and marked `Inexact`.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            // Populate the statistics up front so Min/Max exist when the writer truncates them.
            let mut stats_ctx = session.create_execution_ctx();
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>(), &mut stats_ctx)
                    .vortex_expect("stats computation should succeed for test array")
                    .into_iter(),
            );

            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    array.into_array().to_array_stream().sequenced(ptr),
                    eof,
                    &session,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            // Min is a plain 64-byte prefix of the smallest value (a valid lower bound).
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            // Max is the 64-byte prefix with its last character bumped ('i' -> 'j') so that it
            // remains an upper bound.
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    // Round-trips a two-row struct with validity [valid, null] and two u64 fields, checking
    // both the struct validity and the field values.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            let mut ctx_exec = LEGACY_SESSION.create_execution_ctx();
            let session = SESSION.clone().with_handle(handle);
            let mut validity_builder = BitBufferMut::with_capacity(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.freeze();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let ctx = ArrayContext::empty();

            // Write the struct into fresh test segments, keeping the segments for reading back.
            let (layout, segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                let layout = FlatLayoutStrategy::default()
                    .write_stream(
                        ctx,
                        Arc::<TestSegments>::clone(&segments),
                        array.into_array().to_array_stream().sequenced(ptr),
                        eof,
                        &session,
                    )
                    .await
                    .unwrap();

                (layout, segments)
            };

            // Read the whole layout back with an all-true row mask.
            let result: ArrayRef = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result
                    .validity()
                    .unwrap()
                    .execute_mask(result.len(), &mut ctx_exec)
                    .unwrap()
                    .bit_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            let result_struct = result
                .clone()
                .execute::<StructArray>(&mut ctx_exec)
                .unwrap();
            let field_a = result_struct
                .unmasked_field_by_name("a")
                .unwrap()
                .clone()
                .execute::<PrimitiveArray>(&mut ctx_exec)
                .unwrap();
            assert_eq!(field_a.as_slice::<u64>(), &[1, 2]);
            let result_struct_b = result.execute::<StructArray>(&mut ctx_exec).unwrap();
            let field_b = result_struct_b
                .unmasked_field_by_name("b")
                .unwrap()
                .clone()
                .execute::<PrimitiveArray>(&mut ctx_exec)
                .unwrap();
            assert_eq!(field_b.as_slice::<u64>(), &[3, 4]);
        })
    }

    // With an empty allow-list, writing a filter-encoded array must fail with a normalize
    // error naming the `vortex.filter` encoding.
    #[test]
    fn flat_invalid_array_fails() -> VortexResult<()> {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let prim: PrimitiveArray = (0..10).collect();
            let filter = prim.filter(Mask::from_indices(10, vec![2, 3]))?;

            let ctx = ArrayContext::empty();

            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Empty allow-list: no encoding is permitted.
                let allowed = HashSet::default();
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        Arc::<TestSegments>::clone(&segments),
                        filter.into_array().to_array_stream().sequenced(ptr),
                        eof,
                        &session,
                    )
                    .await;

                (layout, segments)
            };

            let err = layout.expect_err("expected error");
            assert!(
                err.to_string()
                    .contains("normalize forbids encoding (vortex.filter)"),
                "unexpected error: {err}"
            );

            Ok(())
        })
    }

    // With the dict encoding on the allow-list, writing a `DictArray` succeeds.
    #[test]
    fn flat_valid_array_writes() -> VortexResult<()> {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let codes: PrimitiveArray = (0u32..10).collect();
            let values: PrimitiveArray = (0..10).collect();
            let dict = DictArray::new(codes.into_array(), values.into_array());

            let ctx = ArrayContext::empty();

            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Allow-list containing only the dict encoding.
                let mut allowed = HashSet::default();
                allowed.insert(Dict.id());
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        Arc::<TestSegments>::clone(&segments),
                        dict.into_array().to_array_stream().sequenced(ptr),
                        eof,
                        &session,
                    )
                    .await;

                (layout, segments)
            };

            assert!(layout.is_ok());

            Ok(())
        })
    }
}
505}