Skip to main content

vortex_layout/layouts/flat/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::Array;
7use vortex_array::ArrayContext;
8use vortex_array::expr::stats::Precision;
9use vortex_array::expr::stats::Stat;
10use vortex_array::expr::stats::StatsProvider;
11use vortex_array::normalize::NormalizeOptions;
12use vortex_array::normalize::Operation;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar::ScalarTruncation;
15use vortex_array::scalar::lower_bound;
16use vortex_array::scalar::upper_bound;
17use vortex_array::serde::SerializeOptions;
18use vortex_array::session::ArrayRegistry;
19use vortex_array::stats::StatsSetRef;
20use vortex_buffer::BufferString;
21use vortex_buffer::ByteBuffer;
22use vortex_dtype::DType;
23use vortex_error::VortexExpect;
24use vortex_error::VortexResult;
25use vortex_error::vortex_bail;
26use vortex_io::runtime::Handle;
27
28use crate::IntoLayout;
29use crate::LayoutRef;
30use crate::LayoutStrategy;
31use crate::layouts::flat::FlatLayout;
32use crate::layouts::flat::flat_layout_inline_array_node;
33use crate::segments::SegmentSinkRef;
34use crate::sequence::SendableSequentialStream;
35use crate::sequence::SequencePointer;
36
37#[derive(Clone)]
38pub struct FlatLayoutStrategy {
39    /// Whether to include padding for memory-mapped reads.
40    pub include_padding: bool,
41    /// Maximum length of variable length statistics
42    pub max_variable_length_statistics_size: usize,
43    /// Optional set of allowed array encodings for normalization.
44    /// If None, then all are allowed.
45    pub allowed_encodings: Option<ArrayRegistry>,
46}
47
48impl Default for FlatLayoutStrategy {
49    fn default() -> Self {
50        Self {
51            include_padding: true,
52            max_variable_length_statistics_size: 64,
53            allowed_encodings: None,
54        }
55    }
56}
57
58impl FlatLayoutStrategy {
59    /// Set whether to include padding for memory-mapped reads.
60    pub fn with_include_padding(mut self, include_padding: bool) -> Self {
61        self.include_padding = include_padding;
62        self
63    }
64
65    /// Set the maximum length of variable length statistics.
66    pub fn with_max_variable_length_statistics_size(mut self, size: usize) -> Self {
67        self.max_variable_length_statistics_size = size;
68        self
69    }
70
71    /// Set the allowed array encodings for normalization.
72    pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self {
73        self.allowed_encodings = Some(allow_encodings);
74        self
75    }
76}
77
78fn truncate_scalar_stat<F: Fn(Scalar) -> Option<(Scalar, bool)>>(
79    statistics: StatsSetRef<'_>,
80    stat: Stat,
81    truncation: F,
82) {
83    if let Some(sv) = statistics.get(stat) {
84        if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) {
85            if truncated && let Some(v) = truncated_value.into_value() {
86                statistics.set(stat, Precision::Inexact(v));
87            }
88        } else {
89            statistics.clear(stat)
90        }
91    }
92}
93
94#[async_trait]
95impl LayoutStrategy for FlatLayoutStrategy {
96    async fn write_stream(
97        &self,
98        ctx: ArrayContext,
99        segment_sink: SegmentSinkRef,
100        mut stream: SendableSequentialStream,
101        _eof: SequencePointer,
102        _handle: Handle,
103    ) -> VortexResult<LayoutRef> {
104        let ctx = ctx.clone();
105        let Some(chunk) = stream.next().await else {
106            vortex_bail!("flat layout needs a single chunk");
107        };
108        let (sequence_id, chunk) = chunk?;
109
110        let row_count = chunk.len() as u64;
111
112        match chunk.dtype() {
113            DType::Utf8(n) => {
114                truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
115                    lower_bound(
116                        BufferString::from_scalar(v)
117                            .vortex_expect("utf8 scalar must be a BufferString"),
118                        self.max_variable_length_statistics_size,
119                        *n,
120                    )
121                });
122                truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
123                    upper_bound(
124                        BufferString::from_scalar(v)
125                            .vortex_expect("utf8 scalar must be a BufferString"),
126                        self.max_variable_length_statistics_size,
127                        *n,
128                    )
129                });
130            }
131            DType::Binary(n) => {
132                truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| {
133                    lower_bound(
134                        ByteBuffer::from_scalar(v)
135                            .vortex_expect("binary scalar must be a ByteBuffer"),
136                        self.max_variable_length_statistics_size,
137                        *n,
138                    )
139                });
140                truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| {
141                    upper_bound(
142                        ByteBuffer::from_scalar(v)
143                            .vortex_expect("binary scalar must be a ByteBuffer"),
144                        self.max_variable_length_statistics_size,
145                        *n,
146                    )
147                });
148            }
149            _ => {}
150        }
151
152        let chunk = if let Some(allowed) = &self.allowed_encodings {
153            chunk.normalize(&mut NormalizeOptions {
154                allowed,
155                operation: Operation::Error,
156            })?
157        } else {
158            chunk
159        };
160
161        let buffers = chunk.serialize(
162            &ctx,
163            &SerializeOptions {
164                offset: 0,
165                include_padding: self.include_padding,
166            },
167        )?;
168        // there is at least the flatbuffer and the length
169        assert!(buffers.len() >= 2);
170        let array_node =
171            flat_layout_inline_array_node().then(|| buffers[buffers.len() - 2].clone());
172        let segment_id = segment_sink.write(sequence_id, buffers).await?;
173
174        let None = stream.next().await else {
175            vortex_bail!("flat layout received stream with more than a single chunk");
176        };
177        Ok(FlatLayout::new_with_metadata(
178            row_count,
179            stream.dtype().clone(),
180            segment_id,
181            ctx.clone(),
182            array_node,
183        )
184        .into_layout())
185    }
186
187    fn buffered_bytes(&self) -> u64 {
188        // FlatLayoutStrategy is a leaf strategy with no child strategies and no buffering
189        0
190    }
191}
192
// Unit tests for `FlatLayoutStrategy`. They cover:
// - statistics retention;
// - truncation of variable-length min/max statistics;
// - round-tripping a struct array through a reader;
// - the `allowed_encodings` normalization check.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Array;
    use vortex_array::ArrayContext;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::MaskFuture;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::DictArray;
    use vortex_array::arrays::DictVTable;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::PrimitiveVTable;
    use vortex_array::arrays::StructArray;
    use vortex_array::builders::ArrayBuilder;
    use vortex_array::builders::VarBinViewBuilder;
    use vortex_array::expr::root;
    use vortex_array::expr::stats::Precision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::session::ArrayRegistry;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBufferMut;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_io::runtime::single::block_on;
    use vortex_mask::AllOr;
    use vortex_mask::Mask;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    // Currently, flat layouts do not force compute stats during write, they only retain
    // pre-computed stats. The array below has no pre-computed stats, so the `IsSorted`
    // assertion is expected to fail; hence `#[should_panic]`.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    // Min/Max of a Utf8 array longer than the default 64-length cap must be written back as
    // truncated, inexact bounds.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            // Both values are longer than 64 characters, so both bounds need truncation.
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            // Pre-compute every statistic so the writer has Min/Max values to truncate.
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_expect("stats computation should succeed for test array")
                    .into_iter(),
            );

            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                // The cut-off "valu" is intentional: the lower bound is the minimum truncated to
                // its first 64 characters.
                Some(Precision::Inexact(
                    // spellchecker:ignore-next-line
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                // The upper bound keeps 64 characters and increments the last one ('i' -> 'j')
                // so that it still sorts at or above the original maximum.
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    // A nullable struct array written as a flat layout must read back with the same validity
    // mask and field values.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            // Validity mask marking row 0 as valid and row 1 as null.
            let mut validity_builder = BitBufferMut::with_capacity(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.freeze();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let ctx = ArrayContext::empty();

            // Write the array through the flat layout strategy into the test segment sink.
            let (layout, segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                let layout = FlatLayoutStrategy::default()
                    .write_stream(
                        ctx,
                        segments.clone(),
                        array.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await
                    .unwrap();

                (layout, segments)
            };

            // We should be able to read the array we just wrote.
            let result: ArrayRef = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.validity_mask().unwrap().bit_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            assert_eq!(
                result
                    .to_struct()
                    .unmasked_field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                result
                    .to_struct()
                    .unmasked_field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }

    // With an allowlist configured, writing an array that uses a non-allowed encoding must fail.
    #[test]
    fn flat_invalid_array_fails() -> VortexResult<()> {
        block_on(|handle| async {
            // Filtering yields an array reported as the `vortex.filter` encoding (see the
            // expected error message below), which is not in the allowlist.
            let prim: PrimitiveArray = (0..10).collect();
            let filter = prim.filter(Mask::from_indices(10, vec![2, 3]))?;

            let ctx = ArrayContext::empty();

            // Attempt to write the array through the flat layout strategy into the test
            // segment sink.
            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Only allow primitive encodings - filter arrays should fail.
                let allowed = ArrayRegistry::default();
                allowed.register(PrimitiveVTable::ID, PrimitiveVTable);
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        segments.clone(),
                        filter.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await;

                (layout, segments)
            };

            let err = layout.expect_err("expected error");
            assert!(
                err.to_string()
                    .contains("normalize forbids encoding (vortex.filter)"),
                "unexpected error: {err}"
            );

            Ok(())
        })
    }

    // With an allowlist covering every encoding the array uses, writing must succeed.
    #[test]
    fn flat_valid_array_writes() -> VortexResult<()> {
        block_on(|handle| async {
            // Dict array whose codes and values are both primitive arrays.
            let codes: PrimitiveArray = (0u32..10).collect();
            let values: PrimitiveArray = (0..10).collect();
            let dict = DictArray::new(codes.into_array(), values.into_array());

            let ctx = ArrayContext::empty();

            // Write the array through the flat layout strategy into the test segment sink.
            let (layout, _segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                // Allow both primitive and dict encodings, so the dict array (with primitive
                // codes and values) passes normalization.
                let allowed = ArrayRegistry::default();
                allowed.register(PrimitiveVTable::ID, PrimitiveVTable);
                allowed.register(DictVTable::ID, DictVTable);
                let layout = FlatLayoutStrategy::default()
                    .with_allow_encodings(allowed)
                    .write_stream(
                        ctx,
                        segments.clone(),
                        dict.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await;

                (layout, segments)
            };

            assert!(layout.is_ok());

            Ok(())
        })
    }
}