Skip to main content

vortex_layout/layouts/flat/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::ArrayContext;
7use vortex_array::ArrayId;
8use vortex_array::dtype::DType;
9use vortex_array::expr::stats::Precision;
10use vortex_array::expr::stats::Stat;
11use vortex_array::expr::stats::StatsProvider;
12use vortex_array::normalize::NormalizeOptions;
13use vortex_array::normalize::Operation;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar::ScalarTruncation;
16use vortex_array::scalar::lower_bound;
17use vortex_array::scalar::upper_bound;
18use vortex_array::serde::SerializeOptions;
19use vortex_array::stats::StatsSetRef;
20use vortex_buffer::BufferString;
21use vortex_buffer::ByteBuffer;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_bail;
25use vortex_session::VortexSession;
26use vortex_session::registry::ReadContext;
27use vortex_utils::aliases::hash_set::HashSet;
28
29use crate::IntoLayout;
30use crate::LayoutRef;
31use crate::LayoutStrategy;
32use crate::layouts::flat::FlatLayout;
33use crate::layouts::flat::flat_layout_inline_array_node;
34use crate::segments::SegmentSinkRef;
35use crate::sequence::SendableSequentialStream;
36use crate::sequence::SequencePointer;
37
/// Configuration for writing a single array chunk as a flat layout.
///
/// The strategy expects its input stream to yield exactly one chunk; the chunk is serialized
/// and written to the segment sink as one segment.
#[derive(Clone)]
pub struct FlatLayoutStrategy {
    /// Whether to include padding for memory-mapped reads.
    ///
    /// Forwarded unchanged to the serializer's `include_padding` option.
    pub include_padding: bool,
    /// Maximum length of variable-length statistics.
    ///
    /// Applied to the `Min` and `Max` statistics of `Utf8` and `Binary` chunks, which are
    /// truncated to this length before the chunk is written. Presumably measured in bytes
    /// (TODO confirm against `lower_bound` / `upper_bound`).
    pub max_variable_length_statistics_size: usize,
    /// Optional set of allowed array encodings for normalization.
    /// If None, then all are allowed.
    ///
    /// When set, the chunk is normalized against this set with `Operation::Error`, so a
    /// disallowed encoding makes the write fail.
    pub allowed_encodings: Option<HashSet<ArrayId>>,
}
48
49impl Default for FlatLayoutStrategy {
50    fn default() -> Self {
51        Self {
52            include_padding: true,
53            max_variable_length_statistics_size: 64,
54            allowed_encodings: None,
55        }
56    }
57}
58
59impl FlatLayoutStrategy {
60    /// Set whether to include padding for memory-mapped reads.
61    pub fn with_include_padding(mut self, include_padding: bool) -> Self {
62        self.include_padding = include_padding;
63        self
64    }
65
66    /// Set the maximum length of variable length statistics.
67    pub fn with_max_variable_length_statistics_size(mut self, size: usize) -> Self {
68        self.max_variable_length_statistics_size = size;
69        self
70    }
71
72    /// Set the allowed array encodings for normalization.
73    pub fn with_allow_encodings(mut self, allow_encodings: HashSet<ArrayId>) -> Self {
74        self.allowed_encodings = Some(allow_encodings);
75        self
76    }
77}
78
79fn truncate_scalar_stat<F: Fn(Scalar) -> Option<(Scalar, bool)>>(
80    statistics: StatsSetRef<'_>,
81    stat: Stat,
82    truncation: F,
83) {
84    if let Some(sv) = statistics.get(stat) {
85        if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) {
86            if truncated && let Some(v) = truncated_value.into_value() {
87                statistics.set(stat, Precision::Inexact(v));
88            }
89        } else {
90            statistics.clear(stat)
91        }
92    }
93}
94
95#[async_trait]
96impl LayoutStrategy for FlatLayoutStrategy {
97    async fn write_stream(
98        &self,
99        ctx: ArrayContext,
100        segment_sink: SegmentSinkRef,
101        mut stream: SendableSequentialStream,
102        _eof: SequencePointer,
103        session: &VortexSession,
104    ) -> VortexResult<LayoutRef> {
105        let ctx = ctx.clone();
106        let Some(chunk) = stream.next().await else {
107            vortex_bail!("flat layout needs a single chunk");
108        };
109        let (sequence_id, chunk) = chunk?;
110
111        let row_count = chunk.len() as u64;
112
113        match chunk.dtype() {
114            DType::Utf8(n) => {
115                truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
116                    lower_bound(
117                        BufferString::from_scalar(v)
118                            .vortex_expect("utf8 scalar must be a BufferString"),
119                        self.max_variable_length_statistics_size,
120                        *n,
121                    )
122                });
123                truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
124                    upper_bound(
125                        BufferString::from_scalar(v)
126                            .vortex_expect("utf8 scalar must be a BufferString"),
127                        self.max_variable_length_statistics_size,
128                        *n,
129                    )
130                });
131            }
132            DType::Binary(n) => {
133                truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
134                    lower_bound(
135                        ByteBuffer::from_scalar(v)
136                            .vortex_expect("binary scalar must be a ByteBuffer"),
137                        self.max_variable_length_statistics_size,
138                        *n,
139                    )
140                });
141                truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
142                    upper_bound(
143                        ByteBuffer::from_scalar(v)
144                            .vortex_expect("binary scalar must be a ByteBuffer"),
145                        self.max_variable_length_statistics_size,
146                        *n,
147                    )
148                });
149            }
150            _ => {}
151        }
152
153        let chunk = if let Some(allowed) = &self.allowed_encodings {
154            chunk.normalize(&mut NormalizeOptions {
155                allowed,
156                operation: Operation::Error,
157            })?
158        } else {
159            chunk
160        };
161
162        let buffers = chunk.serialize(
163            &ctx,
164            session,
165            &SerializeOptions {
166                offset: 0,
167                include_padding: self.include_padding,
168            },
169        )?;
170        // there is at least the flatbuffer and the length
171        assert!(buffers.len() >= 2);
172        let array_node =
173            flat_layout_inline_array_node().then(|| buffers[buffers.len() - 2].clone());
174        let segment_id = segment_sink.write(sequence_id, buffers).await?;
175
176        let None = stream.next().await else {
177            vortex_bail!("flat layout received stream with more than a single chunk");
178        };
179        Ok(FlatLayout::new_with_metadata(
180            row_count,
181            stream.dtype().clone(),
182            segment_id,
183            ReadContext::new(ctx.to_ids()),
184            array_node,
185        )
186        .into_layout())
187    }
188
189    fn buffered_bytes(&self) -> u64 {
190        // FlatLayoutStrategy is a leaf strategy with no child strategies and no buffering
191        0
192    }
193}
194
195#[cfg(test)]
196mod tests {
197    use std::sync::Arc;
198
199    use vortex_array::ArrayContext;
200    use vortex_array::ArrayRef;
201    use vortex_array::IntoArray;
202    use vortex_array::LEGACY_SESSION;
203    use vortex_array::MaskFuture;
204    use vortex_array::VortexSessionExecute;
205    use vortex_array::arrays::BoolArray;
206    use vortex_array::arrays::Dict;
207    use vortex_array::arrays::DictArray;
208    use vortex_array::arrays::PrimitiveArray;
209    use vortex_array::arrays::StructArray;
210    use vortex_array::arrays::struct_::StructArrayExt;
211    use vortex_array::builders::ArrayBuilder;
212    use vortex_array::builders::VarBinViewBuilder;
213    use vortex_array::dtype::DType;
214    use vortex_array::dtype::FieldName;
215    use vortex_array::dtype::FieldNames;
216    use vortex_array::dtype::Nullability;
217    use vortex_array::expr::root;
218    use vortex_array::expr::stats::Precision;
219    use vortex_array::expr::stats::Stat;
220    use vortex_array::expr::stats::StatsProviderExt;
221    use vortex_array::validity::Validity;
222    use vortex_array::vtable::VTable;
223    use vortex_buffer::BitBufferMut;
224    use vortex_buffer::buffer;
225    use vortex_error::VortexExpect;
226    use vortex_error::VortexResult;
227    use vortex_io::runtime::single::block_on;
228    use vortex_io::session::RuntimeSessionExt;
229    use vortex_mask::AllOr;
230    use vortex_mask::Mask;
231    use vortex_utils::aliases::hash_set::HashSet;
232
233    use crate::LayoutStrategy;
234    use crate::layouts::flat::writer::FlatLayoutStrategy;
235    use crate::segments::TestSegments;
236    use crate::sequence::SequenceId;
237    use crate::sequence::SequentialArrayStreamExt;
238    use crate::test::SESSION;
239
240    // Currently, flat layouts do not force compute stats during write, they only retain
241    // pre-computed stats.
242    #[should_panic]
243    #[test]
244    fn flat_stats() {
245        block_on(|handle| async {
246            let session = SESSION.clone().with_handle(handle);
247            let ctx = ArrayContext::empty();
248            let segments = Arc::new(TestSegments::default());
249            let (ptr, eof) = SequenceId::root().split();
250            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
251            let layout = FlatLayoutStrategy::default()
252                .write_stream(
253                    ctx,
254                    Arc::<TestSegments>::clone(&segments),
255                    array.into_array().to_array_stream().sequenced(ptr),
256                    eof,
257                    &session,
258                )
259                .await
260                .unwrap();
261
262            let result = layout
263                .new_reader("".into(), segments, &SESSION)
264                .unwrap()
265                .projection_evaluation(
266                    &(0..layout.row_count()),
267                    &root(),
268                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
269                )
270                .unwrap()
271                .await
272                .unwrap();
273
274            assert_eq!(
275                result.statistics().get_as::<bool>(Stat::IsSorted),
276                Some(Precision::Exact(true))
277            );
278        })
279    }
280
281    #[test]
282    fn truncates_variable_size_stats() {
283        block_on(|handle| async {
284            let session = SESSION.clone().with_handle(handle);
285            let ctx = ArrayContext::empty();
286            let segments = Arc::new(TestSegments::default());
287            let (ptr, eof) = SequenceId::root().split();
288            let mut builder =
289                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
290            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
291            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
292            let array = builder.finish();
293            let mut stats_ctx = session.create_execution_ctx();
294            array.statistics().set_iter(
295                array
296                    .statistics()
297                    .compute_all(&Stat::all().collect::<Vec<_>>(), &mut stats_ctx)
298                    .vortex_expect("stats computation should succeed for test array")
299                    .into_iter(),
300            );
301
302            let layout = FlatLayoutStrategy::default()
303                .write_stream(
304                    ctx,
305                    Arc::<TestSegments>::clone(&segments),
306                    array.into_array().to_array_stream().sequenced(ptr),
307                    eof,
308                    &session,
309                )
310                .await
311                .unwrap();
312
313            let result = layout
314                .new_reader("".into(), segments, &SESSION)
315                .unwrap()
316                .projection_evaluation(
317                    &(0..layout.row_count()),
318                    &root(),
319                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
320                )
321                .unwrap()
322                .await
323                .unwrap();
324
325            assert_eq!(
326                result.statistics().get_as::<String>(Stat::Min),
327                // The typo is correct, we need this to be truncated.
328                Some(Precision::Inexact(
329                    // spellchecker:ignore-next-line
330                    "Another string that's meant to be smaller than the previous valu".to_string()
331                ))
332            );
333            assert_eq!(
334                result.statistics().get_as::<String>(Stat::Max),
335                Some(Precision::Inexact(
336                    "Long value to test that the statistics are actually truncated, j".to_string()
337                ))
338            );
339        })
340    }
341
342    #[test]
343    fn struct_array_round_trip() {
344        block_on(|handle| async {
345            let mut ctx_exec = LEGACY_SESSION.create_execution_ctx();
346            let session = SESSION.clone().with_handle(handle);
347            let mut validity_builder = BitBufferMut::with_capacity(2);
348            validity_builder.append(true);
349            validity_builder.append(false);
350            let validity_boolean_buffer = validity_builder.freeze();
351            let validity = Validity::Array(
352                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
353            );
354            let array = StructArray::try_new(
355                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
356                vec![
357                    buffer![1_u64, 2].into_array(),
358                    buffer![3_u64, 4].into_array(),
359                ],
360                2,
361                validity,
362            )
363            .unwrap();
364
365            let ctx = ArrayContext::empty();
366
367            // Write the array into a byte buffer.
368            let (layout, segments) = {
369                let segments = Arc::new(TestSegments::default());
370                let (ptr, eof) = SequenceId::root().split();
371                let layout = FlatLayoutStrategy::default()
372                    .write_stream(
373                        ctx,
374                        Arc::<TestSegments>::clone(&segments),
375                        array.into_array().to_array_stream().sequenced(ptr),
376                        eof,
377                        &session,
378                    )
379                    .await
380                    .unwrap();
381
382                (layout, segments)
383            };
384
385            // We should be able to read the array we just wrote.
386            let result: ArrayRef = layout
387                .new_reader("".into(), segments, &SESSION)
388                .unwrap()
389                .projection_evaluation(
390                    &(0..layout.row_count()),
391                    &root(),
392                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
393                )
394                .unwrap()
395                .await
396                .unwrap();
397
398            assert_eq!(
399                result
400                    .validity()
401                    .unwrap()
402                    .execute_mask(result.len(), &mut ctx_exec)
403                    .unwrap()
404                    .bit_buffer(),
405                AllOr::Some(&validity_boolean_buffer)
406            );
407            let result_struct = result
408                .clone()
409                .execute::<StructArray>(&mut ctx_exec)
410                .unwrap();
411            let field_a = result_struct
412                .unmasked_field_by_name("a")
413                .unwrap()
414                .clone()
415                .execute::<PrimitiveArray>(&mut ctx_exec)
416                .unwrap();
417            assert_eq!(field_a.as_slice::<u64>(), &[1, 2]);
418            let result_struct_b = result.execute::<StructArray>(&mut ctx_exec).unwrap();
419            let field_b = result_struct_b
420                .unmasked_field_by_name("b")
421                .unwrap()
422                .clone()
423                .execute::<PrimitiveArray>(&mut ctx_exec)
424                .unwrap();
425            assert_eq!(field_b.as_slice::<u64>(), &[3, 4]);
426        })
427    }
428
429    #[test]
430    fn flat_invalid_array_fails() -> VortexResult<()> {
431        block_on(|handle| async {
432            let session = SESSION.clone().with_handle(handle);
433            let prim: PrimitiveArray = (0..10).collect();
434            let filter = prim.filter(Mask::from_indices(10, vec![2, 3]))?;
435
436            let ctx = ArrayContext::empty();
437
438            // Write the array into a byte buffer.
439            let (layout, _segments) = {
440                let segments = Arc::new(TestSegments::default());
441                let (ptr, eof) = SequenceId::root().split();
442                // Disallow all encodings so filter arrays fail normalization immediately.
443                let allowed = HashSet::default();
444                let layout = FlatLayoutStrategy::default()
445                    .with_allow_encodings(allowed)
446                    .write_stream(
447                        ctx,
448                        Arc::<TestSegments>::clone(&segments),
449                        filter.into_array().to_array_stream().sequenced(ptr),
450                        eof,
451                        &session,
452                    )
453                    .await;
454
455                (layout, segments)
456            };
457
458            let err = layout.expect_err("expected error");
459            assert!(
460                err.to_string()
461                    .contains("normalize forbids encoding (vortex.filter)"),
462                "unexpected error: {err}"
463            );
464
465            Ok(())
466        })
467    }
468
469    #[test]
470    fn flat_valid_array_writes() -> VortexResult<()> {
471        block_on(|handle| async {
472            let session = SESSION.clone().with_handle(handle);
473            let codes: PrimitiveArray = (0u32..10).collect();
474            let values: PrimitiveArray = (0..10).collect();
475            let dict = DictArray::new(codes.into_array(), values.into_array());
476
477            let ctx = ArrayContext::empty();
478
479            // Write the array into a byte buffer.
480            let (layout, _segments) = {
481                let segments = Arc::new(TestSegments::default());
482                let (ptr, eof) = SequenceId::root().split();
483                // Only allow the dict encoding; canonical primitive children remain permitted.
484                let mut allowed = HashSet::default();
485                allowed.insert(Dict.id());
486                let layout = FlatLayoutStrategy::default()
487                    .with_allow_encodings(allowed)
488                    .write_stream(
489                        ctx,
490                        Arc::<TestSegments>::clone(&segments),
491                        dict.into_array().to_array_stream().sequenced(ptr),
492                        eof,
493                        &session,
494                    )
495                    .await;
496
497                (layout, segments)
498            };
499
500            assert!(layout.is_ok());
501
502            Ok(())
503        })
504    }
505}