// vortex_layout/layouts/flat/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::Array;
7use vortex_array::ArrayContext;
8use vortex_array::expr::stats::Precision;
9use vortex_array::expr::stats::Stat;
10use vortex_array::expr::stats::StatsProvider;
11use vortex_array::serde::SerializeOptions;
12use vortex_dtype::DType;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_io::runtime::Handle;
16
17use crate::IntoLayout;
18use crate::LayoutRef;
19use crate::LayoutStrategy;
20use crate::layouts::flat::FLAT_LAYOUT_INLINE_ARRAY_NODE;
21use crate::layouts::flat::FlatLayout;
22use crate::layouts::zoned::lower_bound;
23use crate::layouts::zoned::upper_bound;
24use crate::segments::SegmentSinkRef;
25use crate::sequence::SendableSequentialStream;
26use crate::sequence::SequencePointer;
27
/// Leaf layout strategy that serializes the single chunk of its input stream into one segment
/// and describes it with a `FlatLayout`.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Whether serialized buffers include padding for memory-mapped reads.
    pub include_padding: bool,
    /// Maximum length of the `Min`/`Max` statistics kept for `Utf8` and `Binary` arrays.
    ///
    /// Longer pre-computed statistics are truncated on write and marked inexact.
    pub max_variable_length_statistics_size: usize,
}
35
36impl Default for FlatLayoutStrategy {
37    fn default() -> Self {
38        Self {
39            include_padding: true,
40            max_variable_length_statistics_size: 64,
41        }
42    }
43}
44
45#[async_trait]
46impl LayoutStrategy for FlatLayoutStrategy {
47    async fn write_stream(
48        &self,
49        ctx: ArrayContext,
50        segment_sink: SegmentSinkRef,
51        mut stream: SendableSequentialStream,
52        _eof: SequencePointer,
53        _handle: Handle,
54    ) -> VortexResult<LayoutRef> {
55        let ctx = ctx.clone();
56        let options = self.clone();
57        let Some(chunk) = stream.next().await else {
58            vortex_bail!("flat layout needs a single chunk");
59        };
60        let (sequence_id, chunk) = chunk?;
61
62        let row_count = chunk.len() as u64;
63
64        match chunk.dtype() {
65            DType::Utf8(_) => {
66                if let Some(sv) = chunk.statistics().get(Stat::Min) {
67                    let (value, truncated) = lower_bound(
68                        sv.into_inner().as_utf8(),
69                        options.max_variable_length_statistics_size,
70                    );
71                    if truncated {
72                        chunk
73                            .statistics()
74                            .set(Stat::Min, Precision::Inexact(value.into_value()));
75                    }
76                }
77
78                if let Some(sv) = chunk.statistics().get(Stat::Max) {
79                    let (value, truncated) = upper_bound(
80                        sv.into_inner().as_utf8(),
81                        options.max_variable_length_statistics_size,
82                    );
83                    if let Some(upper_bound) = value {
84                        if truncated {
85                            chunk
86                                .statistics()
87                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
88                        }
89                    } else {
90                        chunk.statistics().clear(Stat::Max)
91                    }
92                }
93            }
94            DType::Binary(_) => {
95                if let Some(sv) = chunk.statistics().get(Stat::Min) {
96                    let (value, truncated) = lower_bound(
97                        sv.into_inner().as_binary(),
98                        options.max_variable_length_statistics_size,
99                    );
100                    if truncated {
101                        chunk
102                            .statistics()
103                            .set(Stat::Min, Precision::Inexact(value.into_value()));
104                    }
105                }
106
107                if let Some(sv) = chunk.statistics().get(Stat::Max) {
108                    let (value, truncated) = upper_bound(
109                        sv.into_inner().as_binary(),
110                        options.max_variable_length_statistics_size,
111                    );
112                    if let Some(upper_bound) = value {
113                        if truncated {
114                            chunk
115                                .statistics()
116                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
117                        }
118                    } else {
119                        chunk.statistics().clear(Stat::Max)
120                    }
121                }
122            }
123            _ => {}
124        }
125
126        // TODO(os): spawn serialization
127        let buffers = chunk.serialize(
128            &ctx,
129            &SerializeOptions {
130                offset: 0,
131                include_padding: options.include_padding,
132            },
133        )?;
134        // there is at least the flatbuffer and the length
135        assert!(buffers.len() >= 2);
136        let array_node =
137            (*FLAT_LAYOUT_INLINE_ARRAY_NODE).then(|| buffers[buffers.len() - 2].clone());
138        let segment_id = segment_sink.write(sequence_id, buffers).await?;
139
140        let None = stream.next().await else {
141            vortex_bail!("flat layout received stream with more than a single chunk");
142        };
143        Ok(FlatLayout::new_with_metadata(
144            row_count,
145            stream.dtype().clone(),
146            segment_id,
147            ctx.clone(),
148            array_node,
149        )
150        .into_layout())
151    }
152
153    fn buffered_bytes(&self) -> u64 {
154        // FlatLayoutStrategy is a leaf strategy with no child strategies and no buffering
155        0
156    }
157}
158
/// Round-trip tests for [`FlatLayoutStrategy`]: each test writes an array through
/// `write_stream` into `TestSegments`, reads it back with a full projection, and inspects
/// the result.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Array;
    use vortex_array::ArrayContext;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::MaskFuture;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::builders::ArrayBuilder;
    use vortex_array::builders::VarBinViewBuilder;
    use vortex_array::expr::root;
    use vortex_array::expr::stats::Precision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBufferMut;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_error::VortexUnwrap;
    use vortex_io::runtime::single::block_on;
    use vortex_mask::AllOr;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    // Currently, flat layouts do not force compute stats during write, they only retain
    // pre-computed stats.
    //
    // The array below is written without any pre-computed stats, so `IsSorted` is presumably
    // absent after read-back and the final `assert_eq!` panics, which is the expected outcome.
    // NOTE(review): a bare `#[should_panic]` also passes if any earlier `unwrap()` panics;
    // consider `#[should_panic(expected = ...)]` or asserting the stat's absence directly.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    /// Pre-computed `Min`/`Max` of a `Utf8` array longer than the default limit (64) must come
    /// back truncated and marked inexact.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            // Compute every statistic up front: the writer only adjusts stats that already exist.
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_unwrap()
                    .into_iter(),
            );

            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                // Not a typo: the expected min is the 64-character prefix left after
                // truncation, so its final word is intentionally cut short.
                Some(Precision::Inexact(
                    // spellchecker:ignore-next-line
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            // The max is the 64-character prefix with its last character bumped ('i' -> 'j'),
            // so it still bounds the original value from above.
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    /// A nullable struct array (validity `[true, false]`) with two `u64` fields must survive a
    /// write/read round trip with its validity and field values intact.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            let mut validity_builder = BitBufferMut::with_capacity(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.freeze();
            let validity = Validity::Array(
                BoolArray::from_bit_buffer(validity_boolean_buffer.clone(), Validity::NonNullable)
                    .into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let ctx = ArrayContext::empty();

            // Write the array into `TestSegments`.
            let (layout, segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                let layout = FlatLayoutStrategy::default()
                    .write_stream(
                        ctx,
                        segments.clone(),
                        array.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await
                    .unwrap();

                (layout, segments)
            };

            // We should be able to read the array we just wrote.
            let result: ArrayRef = layout
                .new_reader("".into(), segments, &SESSION)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            // Struct-level validity must match what was written.
            assert_eq!(
                result.validity_mask().bit_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            assert_eq!(
                result
                    .to_struct()
                    .field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                result
                    .to_struct()
                    .field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }
}